Compare commits

..

5 Commits

Author SHA1 Message Date
Lev Kokotov
61b9756ded unnecessary 2022-08-16 15:52:14 -07:00
Lev Kokotov
2cd9e15849 unused import 2022-08-16 15:47:54 -07:00
Lev Kokotov
fd57fae280 better avg calc 2022-08-16 15:46:23 -07:00
Lev Kokotov
a460a645f5 quick refactor 2022-08-16 15:34:36 -07:00
Lev Kokotov
f7d33fba7a Log stats-own generated events 2022-08-16 15:00:31 -07:00
17 changed files with 643 additions and 913 deletions

View File

@@ -108,7 +108,6 @@ servers = [
]
# Database name (e.g. "postgres")
database = "shard0"
search_path = "\"$user\",public"
[pools.sharded_db.shards.1]
servers = [

View File

@@ -3,9 +3,6 @@
set -e
set -o xtrace
# non-zero exit code if we provide bad configs
(! ./target/debug/pgcat "fake_configs" 2>/dev/null)
# Start PgCat with a particular log level
# for inspection.
function start_pgcat() {
@@ -22,8 +19,8 @@ PGPASSWORD=sharding_user pgbench -h 127.0.0.1 -U sharding_user shard1 -i
PGPASSWORD=sharding_user pgbench -h 127.0.0.1 -U sharding_user shard2 -i
# Install Toxiproxy to simulate a downed/slow database
wget -O toxiproxy-2.4.0.deb https://github.com/Shopify/toxiproxy/releases/download/v2.4.0/toxiproxy_2.4.0_linux_$(dpkg --print-architecture).deb
sudo dpkg -i toxiproxy-2.4.0.deb
wget -O toxiproxy-2.1.4.deb https://github.com/Shopify/toxiproxy/releases/download/v2.1.4/toxiproxy_2.1.4_amd64.deb
sudo dpkg -i toxiproxy-2.1.4.deb
# Start Toxiproxy
toxiproxy-server &
@@ -132,14 +129,11 @@ toxiproxy-cli toxic remove --toxicName latency_downstream postgres_replica
start_pgcat "info"
# Test session mode (and config reload)
sed -i '0,/simple_db/s/pool_mode = "transaction"/pool_mode = "session"/' .circleci/pgcat.toml
sed -i 's/pool_mode = "transaction"/pool_mode = "session"/' .circleci/pgcat.toml
# Reload config test
kill -SIGHUP $(pgrep pgcat)
# Revert settings after reload. Makes test runs idempotent
sed -i '0,/simple_db/s/pool_mode = "session"/pool_mode = "transaction"/' .circleci/pgcat.toml
sleep 1
# Prepared statements that will only work in session mode

7
Cargo.lock generated
View File

@@ -159,12 +159,6 @@ dependencies = [
"termcolor",
]
[[package]]
name = "exitcode"
version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "de853764b47027c2e862a995c34978ffa63c1501f2e15f987ba11bd4f9bba193"
[[package]]
name = "fnv"
version = "1.0.7"
@@ -521,7 +515,6 @@ dependencies = [
"bytes",
"chrono",
"env_logger",
"exitcode",
"hmac",
"hyper",
"log",

View File

@@ -33,4 +33,3 @@ tokio-rustls = "0.23"
rustls-pemfile = "1"
hyper = { version = "0.14", features = ["full"] }
phf = { version = "0.10", features = ["macros"] }
exitcode = "1.1.2"

View File

@@ -15,7 +15,7 @@ PostgreSQL pooler (like PgBouncer) with sharding, load balancing and failover su
| Session pooling | :white_check_mark: | Identical to PgBouncer. |
| `COPY` support | :white_check_mark: | Both `COPY TO` and `COPY FROM` are supported. |
| Query cancellation | :white_check_mark: | Supported both in transaction and session pooling modes. |
| Load balancing of read queries | :white_check_mark: | Using random between replicas. Primary is included when `primary_reads_enabled` is enabled (default). |
| Load balancing of read queries | :white_check_mark: | Using round-robin between replicas. Primary is included when `primary_reads_enabled` is enabled (default). |
| Sharding | :white_check_mark: | Transactions are sharded using `SET SHARD TO` and `SET SHARDING KEY TO` syntax extensions; see examples below. |
| Failover | :white_check_mark: | Replicas are tested with a health check. If a health check fails, remaining replicas are attempted; see below for algorithm description and examples. |
| Statistics | :white_check_mark: | Statistics available in the admin database (`pgcat` and `pgbouncer`) with `SHOW STATS`, `SHOW POOLS` and others. |
@@ -63,7 +63,7 @@ psql -h 127.0.0.1 -p 6432 -c 'SELECT 1'
| `database` | The name of the database to connect to. This is the same on all servers that are part of one shard. | |
| | | |
| **`query_router`** | | |
| `default_role` | Traffic is routed to this role by default (random), unless the client specifies otherwise. Default is `any`, for any role available. | `any`, `primary`, `replica` |
| `default_role` | Traffic is routed to this role by default (round-robin), unless the client specifies otherwise. Default is `any`, for any role available. | `any`, `primary`, `replica` |
| `query_parser_enabled` | Enable the query parser which will inspect incoming queries and route them to a primary or replicas. | `false` |
| `primary_reads_enabled` | Enable this to allow read queries on the primary; otherwise read queries are routed to the replicas. | `true` |
@@ -112,7 +112,7 @@ In transaction mode, a client talks to one server for the duration of a single t
This mode is enabled by default.
### Load balancing of read queries
All queries are load balanced against the configured servers using the random algorithm. The most straight forward configuration example would be to put this pooler in front of several replicas and let it load balance all queries.
All queries are load balanced against the configured servers using the round-robin algorithm. The most straight forward configuration example would be to put this pooler in front of several replicas and let it load balance all queries.
If the configuration includes a primary and replicas, the queries can be separated with the built-in query parser. The query parser will interpret the query and route all `SELECT` queries to a replica, while all other queries including explicit transactions will be routed to the primary.
@@ -151,18 +151,18 @@ Failover behavior can get pretty interesting (read complex) when multiple config
| **Query** | **`SET SERVER ROLE TO`** | **`query_parser_enabled`** | **`primary_reads_enabled`** | **Target state** | **Outcome** |
|---------------------------|--------------------------|----------------------------|-----------------------------|------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| Read query, i.e. `SELECT` | unset (any) | false | false | up | Query is routed to the first instance in the random loop. |
| Read query | unset (any) | true | false | up | Query is routed to the first replica instance in the random loop. |
| Read query | unset (any) | true | true | up | Query is routed to the first instance in the random loop. |
| Read query | replica | false | false | up | Query is routed to the first replica instance in the random loop. |
| Read query, i.e. `SELECT` | unset (any) | false | false | up | Query is routed to the first instance in the round-robin loop. |
| Read query | unset (any) | true | false | up | Query is routed to the first replica instance in the round-robin loop. |
| Read query | unset (any) | true | true | up | Query is routed to the first instance in the round-robin loop. |
| Read query | replica | false | false | up | Query is routed to the first replica instance in the round-robin loop. |
| Read query | primary | false | false | up | Query is routed to the primary. |
| Read query | unset (any) | false | false | down | First instance is banned for reads. Next target in the random loop is attempted. |
| Read query | unset (any) | true | false | down | First replica instance is banned. Next replica instance is attempted in the random loop. |
| Read query | unset (any) | true | true | down | First instance (even if primary) is banned for reads. Next instance is attempted in the random loop. |
| Read query | replica | false | false | down | First replica instance is banned. Next replica instance is attempted in the random loop. |
| Read query | unset (any) | false | false | down | First instance is banned for reads. Next target in the round-robin loop is attempted. |
| Read query | unset (any) | true | false | down | First replica instance is banned. Next replica instance is attempted in the round-robin loop. |
| Read query | unset (any) | true | true | down | First instance (even if primary) is banned for reads. Next instance is attempted in the round-robin loop. |
| Read query | replica | false | false | down | First replica instance is banned. Next replica instance is attempted in the round-robin loop. |
| Read query | primary | false | false | down | The query is attempted against the primary and fails. The client receives an error. |
| | | | | | |
| Write query e.g. `INSERT` | unset (any) | false | false | up | The query is attempted against the first available instance in the random loop. If the instance is a replica, the query fails and the client receives an error. |
| Write query e.g. `INSERT` | unset (any) | false | false | up | The query is attempted against the first available instance in the round-robin loop. If the instance is a replica, the query fails and the client receives an error. |
| Write query | unset (any) | true | false | up | The query is routed to the primary. |
| Write query | unset (any) | true | true | up | The query is routed to the primary. |
| Write query | primary | false | false | up | The query is routed to the primary. |

View File

@@ -265,11 +265,11 @@ where
for (_, pool) in get_all_pools() {
let pool_config = pool.settings.clone();
for shard in 0..pool.shards() {
let database_name = &pool.address(shard, 0).database;
let database_name = &pool_config.shards[&shard.to_string()].database;
for server in 0..pool.servers(shard) {
let address = pool.address(shard, server);
let pool_state = pool.pool_state(shard, server);
let banned = pool.is_banned(address, Some(address.role));
let banned = pool.is_banned(address, shard, Some(address.role));
res.put(data_row(&vec![
address.name(), // name

View File

@@ -5,14 +5,13 @@ use std::collections::HashMap;
use tokio::io::{split, AsyncReadExt, BufReader, ReadHalf, WriteHalf};
use tokio::net::TcpStream;
use tokio::sync::broadcast::Receiver;
use tokio::sync::mpsc::Sender;
use crate::admin::{generate_server_info_for_admin, handle_admin};
use crate::config::{get_config, Address};
use crate::constants::*;
use crate::errors::Error;
use crate::messages::*;
use crate::pool::{get_pool, ClientServerMap, ConnectionPool, PoolMode};
use crate::pool::{get_pool, ClientServerMap, ConnectionPool};
use crate::query_router::{Command, QueryRouter};
use crate::server::Server;
use crate::stats::{get_reporter, Reporter};
@@ -59,6 +58,7 @@ pub struct Client<S, T> {
client_server_map: ClientServerMap,
/// Client parameters, e.g. user, client_encoding, etc.
#[allow(dead_code)]
parameters: HashMap<String, String>,
/// Statistics
@@ -73,26 +73,21 @@ pub struct Client<S, T> {
/// Last server process id we talked to.
last_server_id: Option<i32>,
/// Connected to server
connected_to_server: bool,
/// Name of the server pool for this client (This comes from the database name in the connection string)
pool_name: String,
target_pool_name: String,
/// Postgres user for this client (This comes from the user in the connection string)
username: String,
target_user_name: String,
/// Used to notify clients about an impending shutdown
shutdown: Receiver<()>,
shutdown_event_receiver: Receiver<()>,
}
/// Client entrypoint.
pub async fn client_entrypoint(
mut stream: TcpStream,
client_server_map: ClientServerMap,
shutdown: Receiver<()>,
drain: Sender<i8>,
admin_only: bool,
shutdown_event_receiver: Receiver<()>,
) -> Result<(), Error> {
// Figure out if the client wants TLS or not.
let addr = stream.peer_addr().unwrap();
@@ -111,21 +106,11 @@ pub async fn client_entrypoint(
write_all(&mut stream, yes).await?;
// Negotiate TLS.
match startup_tls(stream, client_server_map, shutdown, admin_only).await {
match startup_tls(stream, client_server_map, shutdown_event_receiver).await {
Ok(mut client) => {
info!("Client {:?} connected (TLS)", addr);
if !client.is_admin() {
let _ = drain.send(1).await;
}
let result = client.handle().await;
if !client.is_admin() {
let _ = drain.send(-1).await;
}
result
client.handle().await
}
Err(err) => Err(err),
}
@@ -151,25 +136,14 @@ pub async fn client_entrypoint(
addr,
bytes,
client_server_map,
shutdown,
admin_only,
shutdown_event_receiver,
)
.await
{
Ok(mut client) => {
info!("Client {:?} connected (plain)", addr);
if !client.is_admin() {
let _ = drain.send(1).await;
}
let result = client.handle().await;
if !client.is_admin() {
let _ = drain.send(-1).await;
}
result
client.handle().await
}
Err(err) => Err(err),
}
@@ -192,25 +166,14 @@ pub async fn client_entrypoint(
addr,
bytes,
client_server_map,
shutdown,
admin_only,
shutdown_event_receiver,
)
.await
{
Ok(mut client) => {
info!("Client {:?} connected (plain)", addr);
if client.is_admin() {
let _ = drain.send(1).await;
}
let result = client.handle().await;
if !client.is_admin() {
let _ = drain.send(-1).await;
}
result
client.handle().await
}
Err(err) => Err(err),
}
@@ -221,21 +184,20 @@ pub async fn client_entrypoint(
let (read, write) = split(stream);
// Continue with cancel query request.
match Client::cancel(read, write, addr, bytes, client_server_map, shutdown).await {
match Client::cancel(
read,
write,
addr,
bytes,
client_server_map,
shutdown_event_receiver,
)
.await
{
Ok(mut client) => {
info!("Client {:?} issued a cancel query request", addr);
if client.is_admin() {
let _ = drain.send(1).await;
}
let result = client.handle().await;
if !client.is_admin() {
let _ = drain.send(-1).await;
}
result
client.handle().await
}
Err(err) => Err(err),
@@ -288,8 +250,7 @@ where
pub async fn startup_tls(
stream: TcpStream,
client_server_map: ClientServerMap,
shutdown: Receiver<()>,
admin_only: bool,
shutdown_event_receiver: Receiver<()>,
) -> Result<Client<ReadHalf<TlsStream<TcpStream>>, WriteHalf<TlsStream<TcpStream>>>, Error> {
// Negotiate TLS.
let tls = Tls::new()?;
@@ -319,8 +280,7 @@ pub async fn startup_tls(
addr,
bytes,
client_server_map,
shutdown,
admin_only,
shutdown_event_receiver,
)
.await
}
@@ -335,10 +295,6 @@ where
S: tokio::io::AsyncRead + std::marker::Unpin,
T: tokio::io::AsyncWrite + std::marker::Unpin,
{
pub fn is_admin(&self) -> bool {
self.admin
}
/// Handle Postgres client startup after TLS negotiation is complete
/// or over plain text.
pub async fn startup(
@@ -347,44 +303,29 @@ where
addr: std::net::SocketAddr,
bytes: BytesMut, // The rest of the startup message.
client_server_map: ClientServerMap,
shutdown: Receiver<()>,
admin_only: bool,
shutdown_event_receiver: Receiver<()>,
) -> Result<Client<S, T>, Error> {
let config = get_config();
let stats = get_reporter();
let parameters = parse_startup(bytes.clone())?;
// These two parameters are mandatory by the protocol.
let pool_name = match parameters.get("database") {
trace!("Got StartupMessage");
let parameters = parse_startup(bytes.clone())?;
let target_pool_name = match parameters.get("database") {
Some(db) => db,
None => return Err(Error::ClientError),
};
let username = match parameters.get("user") {
let target_user_name = match parameters.get("user") {
Some(user) => user,
None => return Err(Error::ClientError),
};
let admin = ["pgcat", "pgbouncer"]
.iter()
.filter(|db| *db == &pool_name)
.filter(|db| *db == &target_pool_name)
.count()
== 1;
// Kick any client that's not admin while we're in admin-only mode.
if !admin && admin_only {
debug!(
"Rejecting non-admin connection to {} when in admin only mode",
pool_name
);
error_response_terminal(
&mut write,
&format!("terminating connection due to administrator command"),
)
.await?;
return Err(Error::ShuttingDown);
}
// Generate random backend ID and secret key
let process_id: i32 = rand::random();
let secret_key: i32 = rand::random();
@@ -416,55 +357,46 @@ where
Err(_) => return Err(Error::SocketError),
};
// Authenticate admin user.
let (transaction_mode, server_info) = if admin {
// Compare server and client hashes.
let password_hash = md5_hash_password(
&config.general.admin_username,
&config.general.admin_password,
&salt,
);
let correct_user = config.general.admin_username.as_str();
let correct_password = config.general.admin_password.as_str();
// Compare server and client hashes.
let password_hash = md5_hash_password(correct_user, correct_password, &salt);
if password_hash != password_response {
debug!("Password authentication failed");
wrong_password(&mut write, username).await?;
wrong_password(&mut write, target_user_name).await?;
return Err(Error::ClientError);
}
(false, generate_server_info_for_admin())
}
// Authenticate normal user.
else {
let pool = match get_pool(pool_name.clone(), username.clone()) {
} else {
let target_pool = match get_pool(target_pool_name.clone(), target_user_name.clone()) {
Some(pool) => pool,
None => {
error_response(
&mut write,
&format!(
"No pool configured for database: {:?}, user: {:?}",
pool_name, username
target_pool_name, target_user_name
),
)
.await?;
return Err(Error::ClientError);
}
};
let transaction_mode = target_pool.settings.pool_mode == "transaction";
let server_info = target_pool.server_info();
// Compare server and client hashes.
let password_hash = md5_hash_password(&username, &pool.settings.user.password, &salt);
let correct_password = target_pool.settings.user.password.as_str();
let password_hash = md5_hash_password(&target_user_name, correct_password, &salt);
if password_hash != password_response {
debug!("Password authentication failed");
wrong_password(&mut write, username).await?;
wrong_password(&mut write, &target_user_name).await?;
return Err(Error::ClientError);
}
let transaction_mode = pool.settings.pool_mode == PoolMode::Transaction;
(transaction_mode, pool.server_info())
(transaction_mode, server_info)
};
debug!("Password authentication successful");
@@ -476,25 +408,27 @@ where
trace!("Startup OK");
// Split the read and write streams
// so we can control buffering.
return Ok(Client {
read: BufReader::new(read),
write: write,
addr,
buffer: BytesMut::with_capacity(8196),
cancel_mode: false,
transaction_mode,
process_id,
secret_key,
client_server_map,
transaction_mode: transaction_mode,
process_id: process_id,
secret_key: secret_key,
client_server_map: client_server_map,
parameters: parameters.clone(),
stats: stats,
admin: admin,
last_address_id: None,
last_server_id: None,
pool_name: pool_name.clone(),
username: username.clone(),
shutdown,
connected_to_server: false,
target_pool_name: target_pool_name.clone(),
target_user_name: target_user_name.clone(),
shutdown_event_receiver: shutdown_event_receiver,
});
}
@@ -505,7 +439,7 @@ where
addr: std::net::SocketAddr,
mut bytes: BytesMut, // The rest of the startup message.
client_server_map: ClientServerMap,
shutdown: Receiver<()>,
shutdown_event_receiver: Receiver<()>,
) -> Result<Client<S, T>, Error> {
let process_id = bytes.get_i32();
let secret_key = bytes.get_i32();
@@ -516,18 +450,17 @@ where
buffer: BytesMut::with_capacity(8196),
cancel_mode: true,
transaction_mode: false,
process_id,
secret_key,
client_server_map,
process_id: process_id,
secret_key: secret_key,
client_server_map: client_server_map,
parameters: HashMap::new(),
stats: get_reporter(),
admin: false,
last_address_id: None,
last_server_id: None,
pool_name: String::from("undefined"),
username: String::from("undefined"),
shutdown,
connected_to_server: false,
target_pool_name: String::from("undefined"),
target_user_name: String::from("undefined"),
shutdown_event_receiver: shutdown_event_receiver,
});
}
@@ -548,7 +481,7 @@ where
process_id.clone(),
secret_key.clone(),
address.clone(),
*port,
port.clone(),
),
// The client doesn't know / got the wrong server,
@@ -560,12 +493,13 @@ where
// Opens a new separate connection to the server, sends the backend_id
// and secret_key and then closes it for security reasons. No other interactions
// take place.
return Ok(Server::cancel(&address, port, process_id, secret_key).await?);
return Ok(Server::cancel(&address, &port, process_id, secret_key).await?);
}
// The query router determines where the query is going to go,
// e.g. primary, replica, which shard.
let mut query_router = QueryRouter::new();
let mut round_robin = rand::random();
// Our custom protocol loop.
// We expect the client to either start a transaction with regular queries
@@ -583,19 +517,9 @@ where
// SET SHARDING KEY TO 'bigint';
let mut message = tokio::select! {
_ = self.shutdown.recv() => {
if !self.admin {
error_response_terminal(
&mut self.write,
&format!("terminating connection due to administrator command")
).await?;
return Ok(())
}
// Admin clients ignore shutdown.
else {
read_message(&mut self.read).await?
}
_ = self.shutdown_event_receiver.recv() => {
error_response_terminal(&mut self.write, &format!("terminating connection due to administrator command")).await?;
return Ok(())
},
message_result = read_message(&mut self.read) => message_result?
};
@@ -616,14 +540,15 @@ where
// Get a pool instance referenced by the most up-to-date
// pointer. This ensures we always read the latest config
// when starting a query.
let pool = match get_pool(self.pool_name.clone(), self.username.clone()) {
let pool = match get_pool(self.target_pool_name.clone(), self.target_user_name.clone())
{
Some(pool) => pool,
None => {
error_response(
&mut self.write,
&format!(
"No pool configured for database: {:?}, user: {:?}",
self.pool_name, self.username
self.target_pool_name, self.target_user_name
),
)
.await?;
@@ -644,8 +569,8 @@ where
// SET SHARD TO
Some((Command::SetShard, _)) => {
let shard = query_router.shard();
if shard >= pool.shards() {
// Selected shard is not configured.
if query_router.shard() >= pool.shards() {
// Set the shard back to what it was.
query_router.set_shard(current_shard);
@@ -653,7 +578,7 @@ where
&mut self.write,
&format!(
"shard {} is more than configured {}, staying on shard {}",
shard,
query_router.shard(),
pool.shards(),
current_shard,
),
@@ -706,7 +631,12 @@ where
// Grab a server from the pool.
let connection = match pool
.get(query_router.shard(), query_router.role(), self.process_id)
.get(
query_router.shard(),
query_router.role(),
self.process_id,
round_robin,
)
.await
{
Ok(conn) => {
@@ -714,22 +644,9 @@ where
conn
}
Err(err) => {
// Clients do not expect to get SystemError followed by ReadyForQuery in the middle
// of extended protocol submission. So we will hold off on sending the actual error
// message to the client until we get 'S' message
match message[0] as char {
'P' | 'B' | 'E' | 'D' => (),
_ => {
error_response(
&mut self.write,
"could not get connection from the pool",
)
.await?;
}
};
error!("Could not get connection from pool: {:?}", err);
error_response(&mut self.write, "could not get connection from the pool")
.await?;
continue;
}
};
@@ -738,10 +655,11 @@ where
let address = connection.1;
let server = &mut *reference;
round_robin += 1;
// Server is assigned to the client in case the client wants to
// cancel a query later.
server.claim(self.process_id, self.secret_key);
self.connected_to_server = true;
// Update statistics.
if let Some(last_address_id) = self.last_address_id {
@@ -749,6 +667,7 @@ where
.client_disconnecting(self.process_id, last_address_id);
}
self.stats.client_active(self.process_id, address.id);
self.stats.server_active(server.process_id(), address.id);
self.last_address_id = Some(address.id);
self.last_server_id = Some(server.process_id());
@@ -812,8 +731,43 @@ where
'Q' => {
debug!("Sending query to server");
self.send_and_receive_loop(code, original, server, &address, &pool)
.await?;
self.send_server_message(
server,
original,
&address,
query_router.shard(),
&pool,
)
.await?;
// Read all data the server has to offer, which can be multiple messages
// buffered in 8196 bytes chunks.
loop {
let response = self
.receive_server_message(
server,
&address,
query_router.shard(),
&pool,
)
.await?;
// Send server reply to the client.
match write_all_half(&mut self.write, response).await {
Ok(_) => (),
Err(err) => {
server.mark_bad();
return Err(err);
}
};
if !server.is_data_available() {
break;
}
}
// Report query executed statistics.
self.stats.query(self.process_id, address.id);
if !server.in_transaction() {
// Report transaction executed statistics.
@@ -822,6 +776,7 @@ where
// Release server back to the pool if we are in transaction mode.
// If we are in session mode, we keep the server until the client disconnects.
if self.transaction_mode {
self.stats.server_idle(server.process_id(), address.id);
break;
}
}
@@ -875,23 +830,52 @@ where
self.buffer.put(&original[..]);
self.send_and_receive_loop(
code,
self.buffer.clone(),
self.send_server_message(
server,
self.buffer.clone(),
&address,
query_router.shard(),
&pool,
)
.await?;
self.buffer.clear();
// Read all data the server has to offer, which can be multiple messages
// buffered in 8196 bytes chunks.
loop {
let response = self
.receive_server_message(
server,
&address,
query_router.shard(),
&pool,
)
.await?;
match write_all_half(&mut self.write, response).await {
Ok(_) => (),
Err(err) => {
server.mark_bad();
return Err(err);
}
};
if !server.is_data_available() {
break;
}
}
// Report query executed statistics.
self.stats.query(self.process_id, address.id);
if !server.in_transaction() {
self.stats.transaction(self.process_id, address.id);
// Release server back to the pool if we are in transaction mode.
// If we are in session mode, we keep the server until the client disconnects.
if self.transaction_mode {
self.stats.server_idle(server.process_id(), address.id);
break;
}
}
@@ -901,17 +885,31 @@ where
'd' => {
// Forward the data to the server,
// don't buffer it since it can be rather large.
self.send_server_message(server, original, &address, &pool)
.await?;
self.send_server_message(
server,
original,
&address,
query_router.shard(),
&pool,
)
.await?;
}
// CopyDone or CopyFail
// Copy is done, successfully or not.
'c' | 'f' => {
self.send_server_message(server, original, &address, &pool)
.await?;
self.send_server_message(
server,
original,
&address,
query_router.shard(),
&pool,
)
.await?;
let response = self.receive_server_message(server, &address, &pool).await?;
let response = self
.receive_server_message(server, &address, query_router.shard(), &pool)
.await?;
match write_all_half(&mut self.write, response).await {
Ok(_) => (),
@@ -927,6 +925,7 @@ where
// Release server back to the pool if we are in transaction mode.
// If we are in session mode, we keep the server until the client disconnects.
if self.transaction_mode {
self.stats.server_idle(server.process_id(), address.id);
break;
}
}
@@ -942,8 +941,6 @@ where
// The server is no longer bound to us, we can't cancel it's queries anymore.
debug!("Releasing server back into the pool");
self.stats.server_idle(server.process_id(), address.id);
self.connected_to_server = false;
self.release();
self.stats.client_idle(self.process_id, address.id);
}
@@ -955,54 +952,18 @@ where
guard.remove(&(self.process_id, self.secret_key));
}
async fn send_and_receive_loop(
&mut self,
code: char,
message: BytesMut,
server: &mut Server,
address: &Address,
pool: &ConnectionPool,
) -> Result<(), Error> {
debug!("Sending {} to server", code);
self.send_server_message(server, message, &address, &pool)
.await?;
// Read all data the server has to offer, which can be multiple messages
// buffered in 8196 bytes chunks.
loop {
let response = self.receive_server_message(server, &address, &pool).await?;
match write_all_half(&mut self.write, response).await {
Ok(_) => (),
Err(err) => {
server.mark_bad();
return Err(err);
}
};
if !server.is_data_available() {
break;
}
}
// Report query executed statistics.
self.stats.query(self.process_id, address.id);
Ok(())
}
async fn send_server_message(
&self,
server: &mut Server,
message: BytesMut,
address: &Address,
shard: usize,
pool: &ConnectionPool,
) -> Result<(), Error> {
match server.send(message).await {
Ok(_) => Ok(()),
Err(err) => {
pool.ban(address, self.process_id);
pool.ban(address, shard, self.process_id);
Err(err)
}
}
@@ -1012,6 +973,7 @@ where
&mut self,
server: &mut Server,
address: &Address,
shard: usize,
pool: &ConnectionPool,
) -> Result<BytesMut, Error> {
if pool.settings.user.statement_timeout > 0 {
@@ -1024,7 +986,7 @@ where
Ok(result) => match result {
Ok(message) => Ok(message),
Err(err) => {
pool.ban(address, self.process_id);
pool.ban(address, shard, self.process_id);
error_response_terminal(
&mut self.write,
&format!("error receiving data from server: {:?}", err),
@@ -1039,7 +1001,7 @@ where
address, pool.settings.user.username
);
server.mark_bad();
pool.ban(address, self.process_id);
pool.ban(address, shard, self.process_id);
error_response_terminal(&mut self.write, "pool statement timeout").await?;
Err(Error::StatementTimeout)
}
@@ -1048,7 +1010,7 @@ where
match server.recv().await {
Ok(message) => Ok(message),
Err(err) => {
pool.ban(address, self.process_id);
pool.ban(address, shard, self.process_id);
error_response_terminal(
&mut self.write,
&format!("error receiving data from server: {:?}", err),
@@ -1066,16 +1028,15 @@ impl<S, T> Drop for Client<S, T> {
let mut guard = self.client_server_map.lock();
guard.remove(&(self.process_id, self.secret_key));
// Dirty shutdown
// TODO: refactor, this is not the best way to handle state management.
// Update statistics.
if let Some(address_id) = self.last_address_id {
self.stats.client_disconnecting(self.process_id, address_id);
if self.connected_to_server {
if let Some(process_id) = self.last_server_id {
self.stats.server_idle(process_id, address_id);
}
if let Some(process_id) = self.last_server_id {
self.stats.server_idle(process_id, address_id);
}
}
// self.release();
}
}

View File

@@ -57,38 +57,13 @@ impl PartialEq<Role> for Option<Role> {
/// Address identifying a PostgreSQL server uniquely.
#[derive(Clone, PartialEq, Hash, std::cmp::Eq, Debug)]
pub struct Address {
/// Unique ID per addressable Postgres server.
pub id: usize,
/// Server host.
pub host: String,
/// Server port.
pub port: u16,
/// Shard number of this Postgres server.
pub port: String,
pub shard: usize,
/// The name of the Postgres database.
pub database: String,
/// Default search_path.
pub search_path: Option<String>,
/// Server role: replica, primary.
pub role: Role,
/// If it's a replica, number it for reference and failover.
pub replica_number: usize,
/// Position of the server in the pool for failover.
pub address_index: usize,
/// The name of the user configured to use this pool.
pub username: String,
/// The name of this pool (i.e. database name visible to the client).
pub pool_name: String,
}
impl Default for Address {
@@ -96,15 +71,11 @@ impl Default for Address {
Address {
id: 0,
host: String::from("127.0.0.1"),
port: 5432,
port: String::from("5432"),
shard: 0,
address_index: 0,
replica_number: 0,
database: String::from("database"),
search_path: None,
role: Role::Replica,
username: String::from("username"),
pool_name: String::from("pool_name"),
}
}
}
@@ -113,11 +84,11 @@ impl Address {
/// Address name (aka database) used in `SHOW STATS`, `SHOW DATABASES`, and `SHOW POOLS`.
pub fn name(&self) -> String {
match self.role {
Role::Primary => format!("{}_shard_{}_primary", self.pool_name, self.shard),
Role::Primary => format!("{}_shard_{}_primary", self.database, self.shard),
Role::Replica => format!(
"{}_shard_{}_replica_{}",
self.pool_name, self.shard, self.replica_number
self.database, self.shard, self.replica_number
),
}
}
@@ -210,7 +181,6 @@ impl Default for Pool {
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Shard {
pub database: String,
pub search_path: Option<String>,
pub servers: Vec<(String, u16, String)>,
}
@@ -218,7 +188,6 @@ impl Default for Shard {
fn default() -> Shard {
Shard {
servers: vec![(String::from("localhost"), 5432, String::from("primary"))],
search_path: None,
database: String::from("postgres"),
}
}
@@ -366,9 +335,9 @@ impl Config {
for (pool_name, pool_config) in &self.pools {
// TODO: Make this output prettier (maybe a table?)
info!("--- Settings for pool {} ---", pool_name);
info!(
"[pool: {}] Maximum user connections: {}",
pool_name,
"Pool size from all users: {}",
pool_config
.users
.iter()
@@ -376,39 +345,20 @@ impl Config {
.sum::<u32>()
.to_string()
);
info!("[pool: {}] Pool mode: {}", pool_name, pool_config.pool_mode);
info!(
"[pool: {}] Sharding function: {}",
pool_name, pool_config.sharding_function
);
info!(
"[pool: {}] Primary reads: {}",
pool_name, pool_config.primary_reads_enabled
);
info!(
"[pool: {}] Query router: {}",
pool_name, pool_config.query_parser_enabled
);
info!(
"[pool: {}] Number of shards: {}",
pool_name,
pool_config.shards.len()
);
info!(
"[pool: {}] Number of users: {}",
pool_name,
pool_config.users.len()
);
info!("Pool mode: {}", pool_config.pool_mode);
info!("Sharding function: {}", pool_config.sharding_function);
info!("Primary reads: {}", pool_config.primary_reads_enabled);
info!("Query router: {}", pool_config.query_parser_enabled);
// TODO: Make this prettier.
info!("Number of shards: {}", pool_config.shards.len());
info!("Number of users: {}", pool_config.users.len());
for user in &pool_config.users {
info!(
"[pool: {}][user: {}] Pool size: {}",
pool_name, user.1.username, user.1.pool_size,
"{} pool size: {}, statement timeout: {}",
user.1.username, user.1.pool_size, user.1.statement_timeout
);
info!(
"[pool: {}][user: {}] Statement timeout: {}",
pool_name, user.1.username, user.1.statement_timeout
)
}
}
}
@@ -506,18 +456,6 @@ pub async fn parse(path: &str) -> Result<(), Error> {
}
};
match pool.pool_mode.as_ref() {
"transaction" => (),
"session" => (),
other => {
error!(
"pool_mode can be 'session' or 'transaction', got: '{}'",
other
);
return Err(Error::BadConfig);
}
};
for shard in &pool.shards {
// We use addresses as unique identifiers,
// let's make sure they are unique in the config as well.

View File

@@ -12,5 +12,4 @@ pub enum Error {
ClientError,
TlsError,
StatementTimeout,
ShuttingDown,
}

View File

@@ -24,7 +24,6 @@ extern crate async_trait;
extern crate bb8;
extern crate bytes;
extern crate env_logger;
extern crate exitcode;
extern crate log;
extern crate md5;
extern crate num_cpus;
@@ -67,7 +66,6 @@ mod stats;
mod tls;
use crate::config::{get_config, reload_config, VERSION};
use crate::errors::Error;
use crate::pool::{ClientServerMap, ConnectionPool};
use crate::prometheus::start_metric_server;
use crate::stats::{Collector, Reporter, REPORTER};
@@ -79,7 +77,7 @@ async fn main() {
if !query_router::QueryRouter::setup() {
error!("Could not setup query router");
std::process::exit(exitcode::CONFIG);
return;
}
let args = std::env::args().collect::<Vec<String>>();
@@ -94,7 +92,7 @@ async fn main() {
Ok(_) => (),
Err(err) => {
error!("Config parse error: {:?}", err);
std::process::exit(exitcode::CONFIG);
return;
}
};
@@ -109,7 +107,7 @@ async fn main() {
Ok(addr) => addr,
Err(err) => {
error!("Invalid http address: {}", err);
std::process::exit(exitcode::CONFIG);
return;
}
};
tokio::task::spawn(async move {
@@ -123,7 +121,7 @@ async fn main() {
Ok(sock) => sock,
Err(err) => {
error!("Listener socket error: {:?}", err);
std::process::exit(exitcode::CONFIG);
return;
}
};
@@ -135,160 +133,171 @@ async fn main() {
let client_server_map: ClientServerMap = Arc::new(Mutex::new(HashMap::new()));
// Statistics reporting.
let (stats_tx, stats_rx) = mpsc::channel(100_000);
REPORTER.store(Arc::new(Reporter::new(stats_tx.clone())));
let (tx, rx) = mpsc::channel(100_000);
REPORTER.store(Arc::new(Reporter::new(tx.clone())));
// Connection pool that allows to query all shards and replicas.
match ConnectionPool::from_config(client_server_map.clone()).await {
Ok(_) => (),
Err(err) => {
error!("Pool error: {:?}", err);
std::process::exit(exitcode::CONFIG);
return;
}
};
// Statistics collector task.
let collector_tx = tx.clone();
// Save these for reloading
let reload_client_server_map = client_server_map.clone();
let autoreload_client_server_map = client_server_map.clone();
tokio::task::spawn(async move {
let mut stats_collector = Collector::new(stats_rx, stats_tx.clone());
let mut stats_collector = Collector::new(rx, collector_tx);
stats_collector.collect().await;
});
info!("Config autoreloader: {}", config.general.autoreload);
info!("Waiting for clients");
let (shutdown_event_tx, mut shutdown_event_rx) = broadcast::channel::<()>(1);
let shutdown_event_tx_clone = shutdown_event_tx.clone();
// Client connection loop.
tokio::task::spawn(async move {
// Creates event subscriber for shutdown event, this is dropped when shutdown event is broadcast
let mut listener_shutdown_event_rx = shutdown_event_tx_clone.subscribe();
loop {
let client_server_map = client_server_map.clone();
// Listen for shutdown event and client connection at the same time
let (socket, addr) = tokio::select! {
_ = listener_shutdown_event_rx.recv() => {
// Exits client connection loop which drops listener, listener_shutdown_event_rx and shutdown_event_tx_clone
break;
}
listener_response = listener.accept() => {
match listener_response {
Ok((socket, addr)) => (socket, addr),
Err(err) => {
error!("{:?}", err);
continue;
}
}
}
};
// Used to signal shutdown
let client_shutdown_handler_rx = shutdown_event_tx_clone.subscribe();
// Used to signal that the task has completed
let dummy_tx = shutdown_event_tx_clone.clone();
// Handle client.
tokio::task::spawn(async move {
let start = chrono::offset::Utc::now().naive_utc();
match client::client_entrypoint(
socket,
client_server_map,
client_shutdown_handler_rx,
)
.await
{
Ok(_) => {
let duration = chrono::offset::Utc::now().naive_utc() - start;
info!(
"Client {:?} disconnected, session duration: {}",
addr,
format_duration(&duration)
);
}
Err(err) => {
debug!("Client disconnected with error {:?}", err);
}
};
// Drop this transmitter so receiver knows that the task is completed
drop(dummy_tx);
});
}
});
// Reload config:
// kill -SIGHUP $(pgrep pgcat)
tokio::task::spawn(async move {
let mut stream = unix_signal(SignalKind::hangup()).unwrap();
loop {
stream.recv().await;
info!("Reloading config");
match reload_config(reload_client_server_map.clone()).await {
Ok(_) => (),
Err(_) => continue,
};
get_config().show();
}
});
if config.general.autoreload {
let mut interval = tokio::time::interval(tokio::time::Duration::from_millis(15_000));
tokio::task::spawn(async move {
info!("Config autoreloader started");
loop {
interval.tick().await;
match reload_config(autoreload_client_server_map.clone()).await {
Ok(changed) => {
if changed {
get_config().show()
}
}
Err(_) => (),
};
}
});
}
let mut term_signal = unix_signal(SignalKind::terminate()).unwrap();
let mut interrupt_signal = unix_signal(SignalKind::interrupt()).unwrap();
let mut sighup_signal = unix_signal(SignalKind::hangup()).unwrap();
let mut autoreload_interval = tokio::time::interval(tokio::time::Duration::from_millis(15_000));
let (shutdown_tx, _) = broadcast::channel::<()>(1);
let (drain_tx, mut drain_rx) = mpsc::channel::<i8>(2048);
let (exit_tx, mut exit_rx) = mpsc::channel::<()>(1);
info!("Waiting for clients");
tokio::select! {
// Initiate graceful shutdown sequence on sig int
_ = interrupt_signal.recv() => {
info!("Got SIGINT, waiting for client connection drain now");
let mut admin_only = false;
let mut total_clients = 0;
// Broadcast that client tasks need to finish
shutdown_event_tx.send(()).unwrap();
// Closes transmitter
drop(shutdown_event_tx);
loop {
tokio::select! {
// Reload config:
// kill -SIGHUP $(pgrep pgcat)
_ = sighup_signal.recv() => {
info!("Reloading config");
match reload_config(client_server_map.clone()).await {
Ok(_) => (),
Err(_) => (),
};
get_config().show();
},
_ = autoreload_interval.tick() => {
if config.general.autoreload {
info!("Automatically reloading config");
match reload_config(client_server_map.clone()).await {
Ok(changed) => {
if changed {
get_config().show()
}
}
Err(_) => (),
};
}
},
// Initiate graceful shutdown sequence on sig int
_ = interrupt_signal.recv() => {
info!("Got SIGINT, waiting for client connection drain now");
admin_only = true;
// Broadcast that client tasks need to finish
let _ = shutdown_tx.send(());
let exit_tx = exit_tx.clone();
let _ = drain_tx.send(0).await;
tokio::task::spawn(async move {
let mut interval = tokio::time::interval(tokio::time::Duration::from_millis(config.general.shutdown_timeout));
// First tick fires immediately.
interval.tick().await;
// Second one in the interval time.
interval.tick().await;
// We're done waiting.
error!("Timed out waiting for clients");
let _ = exit_tx.send(()).await;
});
},
_ = term_signal.recv() => break,
new_client = listener.accept() => {
let (socket, addr) = match new_client {
Ok((socket, addr)) => (socket, addr),
Err(err) => {
error!("{:?}", err);
continue;
// This is in a loop because the first event that the receiver receives will be the shutdown event
// This is not what we are waiting for instead, we want the receiver to send an error once all senders are closed which is reached after the shutdown event is received
loop {
match tokio::time::timeout(
tokio::time::Duration::from_millis(config.general.shutdown_timeout),
shutdown_event_rx.recv(),
)
.await
{
Ok(res) => match res {
Ok(_) => {}
Err(_) => break,
},
Err(_) => {
info!("Timed out while waiting for clients to shutdown");
break;
}
};
let shutdown_rx = shutdown_tx.subscribe();
let drain_tx = drain_tx.clone();
let client_server_map = client_server_map.clone();
tokio::task::spawn(async move {
let start = chrono::offset::Utc::now().naive_utc();
match client::client_entrypoint(
socket,
client_server_map,
shutdown_rx,
drain_tx,
admin_only,
)
.await
{
Ok(()) => {
let duration = chrono::offset::Utc::now().naive_utc() - start;
info!(
"Client {:?} disconnected, session duration: {}",
addr,
format_duration(&duration)
);
}
Err(err) => {
match err {
// Don't count the clients we rejected.
Error::ShuttingDown => (),
_ => {
// drain_tx.send(-1).await.unwrap();
}
}
debug!("Client disconnected with error {:?}", err);
}
};
});
}
_ = exit_rx.recv() => {
break;
}
client_ping = drain_rx.recv() => {
let client_ping = client_ping.unwrap();
total_clients += client_ping;
if total_clients == 0 && admin_only {
let _ = exit_tx.send(()).await;
}
}
}
},
_ = term_signal.recv() => (),
}
info!("Shutting down...");

View File

@@ -111,12 +111,7 @@ where
/// Send the startup packet the server. We're pretending we're a Pg client.
/// This tells the server which user we are and what database we want.
pub async fn startup(
stream: &mut TcpStream,
user: &str,
database: &str,
search_path: Option<&String>,
) -> Result<(), Error> {
pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Result<(), Error> {
let mut bytes = BytesMut::with_capacity(25);
bytes.put_i32(196608); // Protocol number
@@ -130,17 +125,6 @@ pub async fn startup(
bytes.put(&b"database\0"[..]);
bytes.put_slice(&database.as_bytes());
bytes.put_u8(0);
// search_path
match search_path {
Some(search_path) => {
bytes.put(&b"options\0"[..]);
bytes.put_slice(&format!("-c search_path={}", search_path).as_bytes());
bytes.put_u8(0);
}
None => (),
};
bytes.put_u8(0); // Null terminator
let len = bytes.len() as i32 + 4i32;

View File

@@ -6,80 +6,44 @@ use chrono::naive::NaiveDateTime;
use log::{debug, error, info, warn};
use once_cell::sync::Lazy;
use parking_lot::{Mutex, RwLock};
use rand::seq::SliceRandom;
use rand::thread_rng;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use crate::config::{get_config, Address, Role, User};
use crate::config::{get_config, Address, Role, Shard, User};
use crate::errors::Error;
use crate::server::Server;
use crate::sharding::ShardingFunction;
use crate::stats::{get_reporter, Reporter};
pub type BanList = Arc<RwLock<Vec<HashMap<Address, NaiveDateTime>>>>;
pub type ClientServerMap = Arc<Mutex<HashMap<(i32, i32), (i32, i32, String, u16)>>>;
pub type ClientServerMap = Arc<Mutex<HashMap<(i32, i32), (i32, i32, String, String)>>>;
pub type PoolMap = HashMap<(String, String), ConnectionPool>;
/// The connection pool, globally available.
/// This is atomic and safe and read-optimized.
/// The pool is recreated dynamically when the config is reloaded.
pub static POOLS: Lazy<ArcSwap<PoolMap>> = Lazy::new(|| ArcSwap::from_pointee(HashMap::default()));
/// Pool mode:
/// - transaction: server serves one transaction,
/// - session: server is attached to the client.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PoolMode {
Session,
Transaction,
}
impl std::fmt::Display for PoolMode {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match *self {
PoolMode::Session => write!(f, "session"),
PoolMode::Transaction => write!(f, "transaction"),
}
}
}
/// Pool settings.
#[derive(Clone, Debug)]
pub struct PoolSettings {
/// Transaction or Session.
pub pool_mode: PoolMode,
// Number of shards.
pub shards: usize,
// Connecting user.
pub pool_mode: String,
pub shards: HashMap<String, Shard>,
pub user: User,
// Default server role to connect to.
pub default_role: Option<Role>,
// Enable/disable query parser.
pub default_role: String,
pub query_parser_enabled: bool,
// Read from the primary as well or not.
pub primary_reads_enabled: bool,
// Sharding function.
pub sharding_function: ShardingFunction,
pub sharding_function: String,
}
impl Default for PoolSettings {
fn default() -> PoolSettings {
PoolSettings {
pool_mode: PoolMode::Transaction,
shards: 1,
pool_mode: String::from("transaction"),
shards: HashMap::from([(String::from("1"), Shard::default())]),
user: User::default(),
default_role: None,
default_role: String::from("any"),
query_parser_enabled: false,
primary_reads_enabled: true,
sharding_function: ShardingFunction::PgBigintHash,
sharding_function: "pg_bigint_hash".to_string(),
}
}
}
@@ -107,7 +71,6 @@ pub struct ConnectionPool {
/// on pool creation and save the K messages here.
server_info: BytesMut,
/// Pool configuration.
pub settings: PoolSettings,
}
@@ -115,13 +78,11 @@ impl ConnectionPool {
/// Construct the connection pool from the configuration.
pub async fn from_config(client_server_map: ClientServerMap) -> Result<(), Error> {
let config = get_config();
let mut new_pools = PoolMap::default();
let mut new_pools = HashMap::new();
let mut address_id = 0;
for (pool_name, pool_config) in &config.pools {
// There is one pool per database/user pair.
for (_, user) in &pool_config.users {
for (_user_index, user_info) in &pool_config.users {
let mut shards = Vec::new();
let mut addresses = Vec::new();
let mut banlist = Vec::new();
@@ -135,11 +96,10 @@ impl ConnectionPool {
// Sort by shard number to ensure consistency.
shard_ids.sort_by_key(|k| k.parse::<i64>().unwrap());
for shard_idx in &shard_ids {
let shard = &pool_config.shards[shard_idx];
for shard_idx in shard_ids {
let shard = &pool_config.shards[&shard_idx];
let mut pools = Vec::new();
let mut servers = Vec::new();
let mut address_index = 0;
let mut replica_number = 0;
for server in shard.servers.iter() {
@@ -154,20 +114,15 @@ impl ConnectionPool {
let address = Address {
id: address_id,
database: shard.database.clone(),
search_path: shard.search_path.clone(),
database: pool_name.clone(),
host: server.0.clone(),
port: server.1 as u16,
port: server.1.to_string(),
role: role,
address_index,
replica_number,
shard: shard_idx.parse::<usize>().unwrap(),
username: user.username.clone(),
pool_name: pool_name.clone(),
};
address_id += 1;
address_index += 1;
if role == Role::Replica {
replica_number += 1;
@@ -175,14 +130,14 @@ impl ConnectionPool {
let manager = ServerPool::new(
address.clone(),
user.clone(),
user_info.clone(),
&shard.database,
client_server_map.clone(),
get_reporter(),
);
let pool = Pool::builder()
.max_size(user.pool_size)
.max_size(user_info.pool_size)
.connection_timeout(std::time::Duration::from_millis(
config.general.connect_timeout,
))
@@ -209,27 +164,13 @@ impl ConnectionPool {
stats: get_reporter(),
server_info: BytesMut::new(),
settings: PoolSettings {
pool_mode: match pool_config.pool_mode.as_str() {
"transaction" => PoolMode::Transaction,
"session" => PoolMode::Session,
_ => unreachable!(),
},
// shards: pool_config.shards.clone(),
shards: shard_ids.len(),
user: user.clone(),
default_role: match pool_config.default_role.as_str() {
"any" => None,
"replica" => Some(Role::Replica),
"primary" => Some(Role::Primary),
_ => unreachable!(),
},
pool_mode: pool_config.pool_mode.clone(),
shards: pool_config.shards.clone(),
user: user_info.clone(),
default_role: pool_config.default_role.clone(),
query_parser_enabled: pool_config.query_parser_enabled.clone(),
primary_reads_enabled: pool_config.primary_reads_enabled,
sharding_function: match pool_config.sharding_function.as_str() {
"pg_bigint_hash" => ShardingFunction::PgBigintHash,
"sha1" => ShardingFunction::Sha1,
_ => unreachable!(),
},
sharding_function: pool_config.sharding_function.clone(),
},
};
@@ -242,9 +183,7 @@ impl ConnectionPool {
return Err(err);
}
};
// There is one pool per database/user pair.
new_pools.insert((pool_name.clone(), user.username.clone()), pool);
new_pools.insert((pool_name.clone(), user_info.username.clone()), pool);
}
}
@@ -260,9 +199,16 @@ impl ConnectionPool {
/// the pooler starts up.
async fn validate(&mut self) -> Result<(), Error> {
let mut server_infos = Vec::new();
let stats = self.stats.clone();
for shard in 0..self.shards() {
for server in 0..self.servers(shard) {
let connection = match self.databases[shard][server].get().await {
let mut round_robin = 0;
for _ in 0..self.servers(shard) {
// To keep stats consistent.
let fake_process_id = 0;
let connection = match self.get(shard, None, fake_process_id, round_robin).await {
Ok(conn) => conn,
Err(err) => {
error!("Shard {} down or misconfigured: {:?}", shard, err);
@@ -270,21 +216,25 @@ impl ConnectionPool {
}
};
let proxy = connection;
let proxy = connection.0;
let address = connection.1;
let server = &*proxy;
let server_info = server.server_info();
stats.client_disconnecting(fake_process_id, address.id);
if server_infos.len() > 0 {
// Compare against the last server checked.
if server_info != server_infos[server_infos.len() - 1] {
warn!(
"{:?} has different server configuration than the last server",
proxy.address()
address
);
}
}
server_infos.push(server_info);
round_robin += 1;
}
}
@@ -294,8 +244,6 @@ impl ConnectionPool {
return Err(Error::AllServersDown);
}
// We're assuming all servers are identical.
// TODO: not true.
self.server_info = server_infos[0].clone();
Ok(())
@@ -304,31 +252,58 @@ impl ConnectionPool {
/// Get a connection from the pool.
pub async fn get(
&self,
shard: usize, // shard number
role: Option<Role>, // primary or replica
process_id: i32, // client id
shard: usize, // shard number
role: Option<Role>, // primary or replica
process_id: i32, // client id
mut round_robin: usize, // round robin offset
) -> Result<(PooledConnection<'_, ServerPool>, Address), Error> {
let now = Instant::now();
let mut candidates: Vec<&Address> = self.addresses[shard]
.iter()
.filter(|address| address.role == role)
.collect();
let addresses = &self.addresses[shard];
// Random load balancing
candidates.shuffle(&mut thread_rng());
let mut allowed_attempts = match role {
// Primary-specific queries get one attempt, if the primary is down,
// nothing we should do about it I think. It's dangerous to retry
// write queries.
Some(Role::Primary) => 1,
// Replicas get to try as many times as there are replicas
// and connections in the pool.
_ => addresses.len(),
};
debug!("Allowed attempts for {:?}: {}", role, allowed_attempts);
let exists = match role {
Some(role) => addresses.iter().filter(|addr| addr.role == role).count() > 0,
None => true,
};
if !exists {
error!("Requested role {:?}, but none are configured", role);
return Err(Error::BadConfig);
}
let healthcheck_timeout = get_config().general.healthcheck_timeout;
let healthcheck_delay = get_config().general.healthcheck_delay as u128;
while !candidates.is_empty() {
// Get the next candidate
let address = match candidates.pop() {
Some(address) => address,
None => break,
};
while allowed_attempts > 0 {
// Round-robin replicas.
round_robin += 1;
if self.is_banned(&address, role) {
debug!("Address {:?} is banned", address);
let index = round_robin % addresses.len();
let address = &addresses[index];
// Make sure you're getting a primary or a replica
// as per request. If no specific role is requested, the first
// available will be chosen.
if address.role != role {
continue;
}
allowed_attempts -= 1;
// Don't attempt to connect to banned servers.
if self.is_banned(address, shard, role) {
continue;
}
@@ -336,14 +311,12 @@ impl ConnectionPool {
self.stats.client_waiting(process_id, address.id);
// Check if we can connect
let mut conn = match self.databases[address.shard][address.address_index]
.get()
.await
{
let mut conn = match self.databases[shard][index].get().await {
Ok(conn) => conn,
Err(err) => {
error!("Banning instance {:?}, error: {:?}", address, err);
self.ban(&address, process_id);
error!("Banning replica {}, error: {:?}", index, err);
self.ban(address, shard, process_id);
self.stats.client_disconnecting(process_id, address.id);
self.stats
.checkout_time(now.elapsed().as_micros(), process_id, address.id);
continue;
@@ -357,13 +330,10 @@ impl ConnectionPool {
let require_healthcheck =
server.last_activity().elapsed().unwrap().as_millis() > healthcheck_delay;
// Do not issue a health check unless it's been a little while
// since we last checked the server is ok.
// Health checks are pretty expensive.
if !require_healthcheck {
self.stats
.checkout_time(now.elapsed().as_micros(), process_id, address.id);
self.stats.server_active(conn.process_id(), address.id);
self.stats.server_idle(conn.process_id(), address.id);
return Ok((conn, address.clone()));
}
@@ -373,7 +343,7 @@ impl ConnectionPool {
match tokio::time::timeout(
tokio::time::Duration::from_millis(healthcheck_timeout),
server.query(";"), // Cheap query (query parser not used in PG)
server.query(";"),
)
.await
{
@@ -382,72 +352,67 @@ impl ConnectionPool {
Ok(_) => {
self.stats
.checkout_time(now.elapsed().as_micros(), process_id, address.id);
self.stats.server_active(conn.process_id(), address.id);
self.stats.server_idle(conn.process_id(), address.id);
return Ok((conn, address.clone()));
}
// Health check failed.
Err(err) => {
error!(
"Banning instance {:?} because of failed health check, {:?}",
address, err
);
Err(_) => {
error!("Banning replica {} because of failed health check", index);
// Don't leave a bad connection in the pool.
server.mark_bad();
self.ban(&address, process_id);
self.ban(address, shard, process_id);
continue;
}
},
// Health check timed out.
Err(err) => {
error!(
"Banning instance {:?} because of health check timeout, {:?}",
address, err
);
Err(_) => {
error!("Banning replica {} because of health check timeout", index);
// Don't leave a bad connection in the pool.
server.mark_bad();
self.ban(&address, process_id);
self.ban(address, shard, process_id);
continue;
}
}
}
Err(Error::AllServersDown)
return Err(Error::AllServersDown);
}
/// Ban an address (i.e. replica). It no longer will serve
/// traffic for any new transactions. Existing transactions on that replica
/// will finish successfully or error out to the clients.
pub fn ban(&self, address: &Address, process_id: i32) {
pub fn ban(&self, address: &Address, shard: usize, process_id: i32) {
self.stats.client_disconnecting(process_id, address.id);
self.stats
.checkout_time(Instant::now().elapsed().as_micros(), process_id, address.id);
error!("Banning {:?}", address);
let now = chrono::offset::Utc::now().naive_utc();
let mut guard = self.banlist.write();
guard[address.shard].insert(address.clone(), now);
guard[shard].insert(address.clone(), now);
}
/// Clear the replica to receive traffic again. Takes effect immediately
/// for all new transactions.
pub fn _unban(&self, address: &Address) {
pub fn _unban(&self, address: &Address, shard: usize) {
let mut guard = self.banlist.write();
guard[address.shard].remove(address);
guard[shard].remove(address);
}
/// Check if a replica can serve traffic. If all replicas are banned,
/// we unban all of them. Better to try then not to.
pub fn is_banned(&self, address: &Address, role: Option<Role>) -> bool {
pub fn is_banned(&self, address: &Address, shard: usize, role: Option<Role>) -> bool {
let replicas_available = match role {
Some(Role::Replica) => self.addresses[address.shard]
Some(Role::Replica) => self.addresses[shard]
.iter()
.filter(|addr| addr.role == Role::Replica)
.count(),
None => self.addresses[address.shard].len(),
None => self.addresses[shard].len(),
Some(Role::Primary) => return false, // Primary cannot be banned.
};
@@ -456,17 +421,17 @@ impl ConnectionPool {
let guard = self.banlist.read();
// Everything is banned = nothing is banned.
if guard[address.shard].len() == replicas_available {
if guard[shard].len() == replicas_available {
drop(guard);
let mut guard = self.banlist.write();
guard[address.shard].clear();
guard[shard].clear();
drop(guard);
warn!("Unbanning all replicas.");
return false;
}
// I expect this to miss 99.9999% of the time.
match guard[address.shard].get(address) {
match guard[shard].get(address) {
Some(timestamp) => {
let now = chrono::offset::Utc::now().naive_utc();
let config = get_config();
@@ -476,7 +441,7 @@ impl ConnectionPool {
drop(guard);
warn!("Unbanning {:?}", address);
let mut guard = self.banlist.write();
guard[address.shard].remove(address);
guard[shard].remove(address);
false
} else {
debug!("{:?} is banned", address);
@@ -613,7 +578,6 @@ pub fn get_pool(db: String, user: String) -> Option<ConnectionPool> {
}
}
/// How many total servers we have in the config.
pub fn get_number_of_addresses() -> usize {
get_all_pools()
.iter()
@@ -621,7 +585,6 @@ pub fn get_number_of_addresses() -> usize {
.sum()
}
/// Get a pointer to all configured pools.
pub fn get_all_pools() -> HashMap<(String, String), ConnectionPool> {
return (*(*POOLS.load())).clone();
}

View File

@@ -10,7 +10,7 @@ use sqlparser::parser::Parser;
use crate::config::Role;
use crate::pool::PoolSettings;
use crate::sharding::Sharder;
use crate::sharding::{Sharder, ShardingFunction};
/// Regexes used to parse custom commands.
const CUSTOM_SQL_REGEXES: [&str; 7] = [
@@ -55,13 +55,11 @@ pub struct QueryRouter {
/// Include the primary into the replica pool for reads.
primary_reads_enabled: bool,
/// Pool configuration.
pool_settings: PoolSettings,
}
impl QueryRouter {
/// One-time initialization of regexes
/// that parse our custom SQL protocol.
/// One-time initialization of regexes.
pub fn setup() -> bool {
let set = match RegexSet::new(&CUSTOM_SQL_REGEXES) {
Ok(rgx) => rgx,
@@ -76,7 +74,10 @@ impl QueryRouter {
.map(|rgx| Regex::new(rgx).unwrap())
.collect();
assert_eq!(list.len(), set.len());
// Impossible
if list.len() != set.len() {
return false;
}
match CUSTOM_SQL_REGEX_LIST.set(list) {
Ok(_) => true,
@@ -89,8 +90,7 @@ impl QueryRouter {
}
}
/// Create a new instance of the query router.
/// Each client gets its own.
/// Create a new instance of the query router. Each client gets its own.
pub fn new() -> QueryRouter {
QueryRouter {
active_shard: None,
@@ -101,7 +101,6 @@ impl QueryRouter {
}
}
/// Pool settings can change because of a config reload.
pub fn update_pool_settings(&mut self, pool_settings: PoolSettings) {
self.pool_settings = pool_settings;
}
@@ -137,6 +136,19 @@ impl QueryRouter {
return None;
}
let sharding_function = match self.pool_settings.sharding_function.as_ref() {
"pg_bigint_hash" => ShardingFunction::PgBigintHash,
"sha1" => ShardingFunction::Sha1,
_ => unreachable!(),
};
let default_server_role = match self.pool_settings.default_role.as_ref() {
"any" => None,
"primary" => Some(Role::Primary),
"replica" => Some(Role::Replica),
_ => unreachable!(),
};
let command = match matches[0] {
0 => Command::SetShardingKey,
1 => Command::SetShard,
@@ -188,10 +200,7 @@ impl QueryRouter {
match command {
Command::SetShardingKey => {
let sharder = Sharder::new(
self.pool_settings.shards,
self.pool_settings.sharding_function,
);
let sharder = Sharder::new(self.pool_settings.shards.len(), sharding_function);
let shard = sharder.shard(value.parse::<i64>().unwrap());
self.active_shard = Some(shard);
value = shard.to_string();
@@ -199,7 +208,7 @@ impl QueryRouter {
Command::SetShard => {
self.active_shard = match value.to_ascii_uppercase().as_ref() {
"ANY" => Some(rand::random::<usize>() % self.pool_settings.shards),
"ANY" => Some(rand::random::<usize>() % self.pool_settings.shards.len()),
_ => Some(value.parse::<usize>().unwrap()),
};
}
@@ -227,7 +236,7 @@ impl QueryRouter {
}
"default" => {
self.active_role = self.pool_settings.default_role;
self.active_role = default_server_role;
self.query_parser_enabled = self.query_parser_enabled;
self.active_role
}
@@ -358,10 +367,10 @@ impl QueryRouter {
#[cfg(test)]
mod test {
use std::collections::HashMap;
use super::*;
use crate::messages::simple_query;
use crate::pool::PoolMode;
use crate::sharding::ShardingFunction;
use bytes::BufMut;
#[test]
@@ -624,13 +633,13 @@ mod test {
QueryRouter::setup();
let pool_settings = PoolSettings {
pool_mode: PoolMode::Transaction,
shards: 0,
pool_mode: "transaction".to_string(),
shards: HashMap::default(),
user: crate::config::User::default(),
default_role: Some(Role::Replica),
default_role: Role::Replica.to_string(),
query_parser_enabled: true,
primary_reads_enabled: false,
sharding_function: ShardingFunction::PgBigintHash,
sharding_function: "pg_bigint_hash".to_string(),
};
let mut qr = QueryRouter::new();
assert_eq!(qr.active_role, None);
@@ -652,6 +661,9 @@ mod test {
let q2 = simple_query("SET SERVER ROLE TO 'default'");
assert!(qr.try_execute_command(q2) != None);
assert_eq!(qr.active_role.unwrap(), pool_settings.clone().default_role);
assert_eq!(
qr.active_role.unwrap().to_string(),
pool_settings.clone().default_role
);
}
}

View File

@@ -75,7 +75,7 @@ impl Server {
stats: Reporter,
) -> Result<Server, Error> {
let mut stream =
match TcpStream::connect(&format!("{}:{}", &address.host, address.port)).await {
match TcpStream::connect(&format!("{}:{}", &address.host, &address.port)).await {
Ok(stream) => stream,
Err(err) => {
error!("Could not connect to server: {}", err);
@@ -86,13 +86,7 @@ impl Server {
trace!("Sending StartupMessage");
// StartupMessage
startup(
&mut stream,
&user.username,
database,
address.search_path.as_ref(),
)
.await?;
startup(&mut stream, &user.username, database).await?;
let mut server_info = BytesMut::new();
let mut process_id: i32 = 0;
@@ -348,7 +342,7 @@ impl Server {
/// Uses a separate connection that's not part of the connection pool.
pub async fn cancel(
host: &str,
port: u16,
port: &str,
process_id: i32,
secret_key: i32,
) -> Result<(), Error> {
@@ -535,7 +529,7 @@ impl Server {
self.process_id,
self.secret_key,
self.address.host.clone(),
self.address.port,
self.address.port.clone(),
),
);
}

View File

@@ -4,6 +4,7 @@ use log::{error, info, trace};
use once_cell::sync::Lazy;
use parking_lot::Mutex;
use std::collections::HashMap;
use std::time::SystemTime;
use tokio::sync::mpsc::error::TrySendError;
use tokio::sync::mpsc::{channel, Receiver, Sender};
@@ -42,6 +43,26 @@ enum EventName {
UpdateAverages,
}
/// Send an event via the channel and log
/// an error if it fails.
fn send(tx: &Sender<Event>, event: Event) {
let name = event.name;
let result = tx.try_send(event);
match result {
Ok(_) => trace!(
"{:?} event reported successfully, capacity: {}",
name,
tx.capacity()
),
Err(err) => match err {
TrySendError::Full { .. } => error!("{:?} event dropped, buffer full", name),
TrySendError::Closed { .. } => error!("{:?} event dropped, channel closed", name),
},
};
}
/// Event data sent to the collector
/// from clients and servers.
#[derive(Debug, Clone)]
@@ -80,25 +101,6 @@ impl Reporter {
Reporter { tx: tx }
}
/// Send statistics to the task keeping track of stats.
fn send(&self, event: Event) {
let name = event.name;
let result = self.tx.try_send(event);
match result {
Ok(_) => trace!(
"{:?} event reported successfully, capacity: {}",
name,
self.tx.capacity()
),
Err(err) => match err {
TrySendError::Full { .. } => error!("{:?} event dropped, buffer full", name),
TrySendError::Closed { .. } => error!("{:?} event dropped, channel closed", name),
},
};
}
/// Report a query executed by a client against
/// a server identified by the `address_id`.
pub fn query(&self, process_id: i32, address_id: usize) {
@@ -109,7 +111,7 @@ impl Reporter {
address_id: address_id,
};
self.send(event);
send(&self.tx, event);
}
/// Report a transaction executed by a client against
@@ -122,7 +124,7 @@ impl Reporter {
address_id: address_id,
};
self.send(event)
send(&self.tx, event)
}
/// Report data sent to a server identified by `address_id`.
@@ -135,7 +137,7 @@ impl Reporter {
address_id: address_id,
};
self.send(event)
send(&self.tx, event)
}
/// Report data received from a server identified by `address_id`.
@@ -148,7 +150,7 @@ impl Reporter {
address_id: address_id,
};
self.send(event)
send(&self.tx, event)
}
/// Time spent waiting to get a healthy connection from the pool
@@ -162,7 +164,7 @@ impl Reporter {
address_id: address_id,
};
self.send(event)
send(&self.tx, event)
}
/// Reports a client identified by `process_id` waiting for a connection
@@ -175,7 +177,7 @@ impl Reporter {
address_id: address_id,
};
self.send(event)
send(&self.tx, event)
}
/// Reports a client identified by `process_id` is done waiting for a connection
@@ -188,7 +190,7 @@ impl Reporter {
address_id: address_id,
};
self.send(event)
send(&self.tx, event)
}
/// Reports a client identified by `process_id` is done querying the server
@@ -201,7 +203,7 @@ impl Reporter {
address_id: address_id,
};
self.send(event)
send(&self.tx, event)
}
/// Reports a client identified by `process_id` is disconecting from the pooler.
@@ -214,7 +216,7 @@ impl Reporter {
address_id: address_id,
};
self.send(event)
send(&self.tx, event)
}
/// Reports a server connection identified by `process_id` for
@@ -228,7 +230,7 @@ impl Reporter {
address_id: address_id,
};
self.send(event)
send(&self.tx, event)
}
/// Reports a server connection identified by `process_id` for
@@ -242,7 +244,7 @@ impl Reporter {
address_id: address_id,
};
self.send(event)
send(&self.tx, event)
}
/// Reports a server connection identified by `process_id` for
@@ -256,7 +258,7 @@ impl Reporter {
address_id: address_id,
};
self.send(event)
send(&self.tx, event)
}
/// Reports a server connection identified by `process_id` for
@@ -270,7 +272,7 @@ impl Reporter {
address_id: address_id,
};
self.send(event)
send(&self.tx, event)
}
/// Reports a server connection identified by `process_id` is disconecting from the pooler.
@@ -283,7 +285,7 @@ impl Reporter {
address_id: address_id,
};
self.send(event)
send(&self.tx, event)
}
}
@@ -345,6 +347,9 @@ impl Collector {
// Track which state the client and server are at any given time.
let mut client_server_states: HashMap<usize, HashMap<i32, EventName>> = HashMap::new();
// Average update times
let mut last_updated_avg: HashMap<usize, SystemTime> = HashMap::new();
// Flush stats to StatsD and calculate averages every 15 seconds.
let tx = self.tx.clone();
tokio::task::spawn(async move {
@@ -354,12 +359,15 @@ impl Collector {
interval.tick().await;
let address_count = get_number_of_addresses();
for address_id in 0..address_count {
let _ = tx.try_send(Event {
name: EventName::UpdateStats,
value: 0,
process_id: -1,
address_id: address_id,
});
send(
&tx,
Event {
name: EventName::UpdateStats,
value: 0,
process_id: -1,
address_id: address_id,
},
);
}
}
});
@@ -372,12 +380,15 @@ impl Collector {
interval.tick().await;
let address_count = get_number_of_addresses();
for address_id in 0..address_count {
let _ = tx.try_send(Event {
name: EventName::UpdateAverages,
value: 0,
process_id: -1,
address_id: address_id,
});
send(
&tx,
Event {
name: EventName::UpdateAverages,
value: 0,
process_id: -1,
address_id: address_id,
},
);
}
}
});
@@ -399,6 +410,9 @@ impl Collector {
.entry(stat.address_id)
.or_insert(HashMap::new());
let old_stats = old_stats.entry(stat.address_id).or_insert(HashMap::new());
let last_updated_avg = last_updated_avg
.entry(stat.address_id)
.or_insert(SystemTime::now());
// Some are counters, some are gauges...
match stat.name {
@@ -524,6 +538,24 @@ impl Collector {
}
EventName::UpdateAverages => {
let elapsed = match last_updated_avg.elapsed() {
Ok(elapsed) => elapsed.as_secs(),
Err(err) => {
error!(
"Could not get elapsed time, averages may be incorrect: {:?}",
err
);
STAT_PERIOD / 1_000
}
} as i64;
*last_updated_avg = SystemTime::now();
// Tokio triggers the interval on first tick and then sleeps.
if elapsed == 0 {
continue;
}
// Calculate averages
for stat in &[
"avg_query_count",
@@ -541,7 +573,7 @@ impl Collector {
let old_value = old_stats.entry(total_name.clone()).or_insert(0);
let new_value = stats.get(total_name.as_str()).unwrap_or(&0).to_owned();
let avg = (new_value - *old_value) / (STAT_PERIOD as i64 / 1_000); // Avg / second
let avg = (new_value - *old_value) / elapsed; // Avg / second
stats.insert(stat, avg);
*old_value = new_value;

View File

@@ -14,7 +14,6 @@ PGCAT_PORT = "6432"
def pgcat_start():
pg_cat_send_signal(signal.SIGTERM)
os.system("./target/debug/pgcat .circleci/pgcat.toml &")
time.sleep(2)
def pg_cat_send_signal(signal: signal.Signals):
@@ -28,23 +27,11 @@ def pg_cat_send_signal(signal: signal.Signals):
raise Exception("pgcat not closed after SIGTERM")
def connect_db(
autocommit: bool = True,
admin: bool = False,
def connect_normal_db(
autocommit: bool = False,
) -> Tuple[psycopg2.extensions.connection, psycopg2.extensions.cursor]:
if admin:
user = "admin_user"
password = "admin_pass"
db = "pgcat"
else:
user = "sharding_user"
password = "sharding_user"
db = "sharded_db"
conn = psycopg2.connect(
f"postgres://{user}:{password}@{PGCAT_HOST}:{PGCAT_PORT}/{db}?application_name=testing_pgcat",
connect_timeout=2,
f"postgres://sharding_user:sharding_user@{PGCAT_HOST}:{PGCAT_PORT}/sharded_db?application_name=testing_pgcat"
)
conn.autocommit = autocommit
cur = conn.cursor()
@@ -58,7 +45,7 @@ def cleanup_conn(conn: psycopg2.extensions.connection, cur: psycopg2.extensions.
def test_normal_db_access():
conn, cur = connect_db(autocommit=False)
conn, cur = connect_normal_db()
cur.execute("SELECT 1")
res = cur.fetchall()
print(res)
@@ -66,7 +53,11 @@ def test_normal_db_access():
def test_admin_db_access():
conn, cur = connect_db(admin=True)
conn = psycopg2.connect(
f"postgres://admin_user:admin_pass@{PGCAT_HOST}:{PGCAT_PORT}/pgcat"
)
conn.autocommit = True # BEGIN/COMMIT is not supported by admin db
cur = conn.cursor()
cur.execute("SHOW POOLS")
res = cur.fetchall()
@@ -76,14 +67,15 @@ def test_admin_db_access():
def test_shutdown_logic():
# - - - - - - - - - - - - - - - - - -
# NO ACTIVE QUERIES SIGINT HANDLING
##### NO ACTIVE QUERIES SIGINT HANDLING #####
# Start pgcat
pgcat_start()
# Wait for server to fully start up
time.sleep(2)
# Create client connection and send query (not in transaction)
conn, cur = connect_db()
conn, cur = connect_normal_db(True)
cur.execute("BEGIN;")
cur.execute("SELECT 1;")
@@ -105,14 +97,17 @@ def test_shutdown_logic():
cleanup_conn(conn, cur)
pg_cat_send_signal(signal.SIGTERM)
# - - - - - - - - - - - - - - - - - -
# HANDLE TRANSACTION WITH SIGINT
##### END #####
##### HANDLE TRANSACTION WITH SIGINT #####
# Start pgcat
pgcat_start()
# Wait for server to fully start up
time.sleep(2)
# Create client connection and begin transaction
conn, cur = connect_db()
conn, cur = connect_normal_db(True)
cur.execute("BEGIN;")
cur.execute("SELECT 1;")
@@ -131,97 +126,17 @@ def test_shutdown_logic():
cleanup_conn(conn, cur)
pg_cat_send_signal(signal.SIGTERM)
# - - - - - - - - - - - - - - - - - -
# NO NEW NON-ADMIN CONNECTIONS DURING SHUTDOWN
##### END #####
##### HANDLE SHUTDOWN TIMEOUT WITH SIGINT #####
# Start pgcat
pgcat_start()
# Create client connection and begin transaction
transaction_conn, transaction_cur = connect_db()
transaction_cur.execute("BEGIN;")
transaction_cur.execute("SELECT 1;")
# Send sigint to pgcat while still in transaction
pg_cat_send_signal(signal.SIGINT)
time.sleep(1)
start = time.perf_counter()
try:
conn, cur = connect_db()
cur.execute("SELECT 1;")
cleanup_conn(conn, cur)
except psycopg2.OperationalError as e:
time_taken = time.perf_counter() - start
if time_taken > 0.1:
raise Exception(
"Failed to reject connection within 0.1 seconds, got", time_taken, "seconds")
pass
else:
raise Exception("Able connect to database during shutdown")
cleanup_conn(transaction_conn, transaction_cur)
pg_cat_send_signal(signal.SIGTERM)
# - - - - - - - - - - - - - - - - - -
# ALLOW NEW ADMIN CONNECTIONS DURING SHUTDOWN
# Start pgcat
pgcat_start()
# Create client connection and begin transaction
transaction_conn, transaction_cur = connect_db()
transaction_cur.execute("BEGIN;")
transaction_cur.execute("SELECT 1;")
# Send sigint to pgcat while still in transaction
pg_cat_send_signal(signal.SIGINT)
time.sleep(1)
try:
conn, cur = connect_db(admin=True)
cur.execute("SHOW DATABASES;")
cleanup_conn(conn, cur)
except psycopg2.OperationalError as e:
raise Exception(e)
cleanup_conn(transaction_conn, transaction_cur)
pg_cat_send_signal(signal.SIGTERM)
# - - - - - - - - - - - - - - - - - -
# ADMIN CONNECTIONS CONTINUING TO WORK AFTER SHUTDOWN
# Start pgcat
pgcat_start()
# Create client connection and begin transaction
transaction_conn, transaction_cur = connect_db()
transaction_cur.execute("BEGIN;")
transaction_cur.execute("SELECT 1;")
admin_conn, admin_cur = connect_db(admin=True)
admin_cur.execute("SHOW DATABASES;")
# Send sigint to pgcat while still in transaction
pg_cat_send_signal(signal.SIGINT)
time.sleep(1)
try:
admin_cur.execute("SHOW DATABASES;")
except psycopg2.OperationalError as e:
raise Exception("Could not execute admin command:", e)
cleanup_conn(transaction_conn, transaction_cur)
cleanup_conn(admin_conn, admin_cur)
pg_cat_send_signal(signal.SIGTERM)
# - - - - - - - - - - - - - - - - - -
# HANDLE SHUTDOWN TIMEOUT WITH SIGINT
# Start pgcat
pgcat_start()
# Wait for server to fully start up
time.sleep(3)
# Create client connection and begin transaction, which should prevent server shutdown unless shutdown timeout is reached
conn, cur = connect_db()
conn, cur = connect_normal_db(True)
cur.execute("BEGIN;")
cur.execute("SELECT 1;")
@@ -244,7 +159,7 @@ def test_shutdown_logic():
cleanup_conn(conn, cur)
pg_cat_send_signal(signal.SIGTERM)
# - - - - - - - - - - - - - - - - - -
##### END #####
test_normal_db_access()

View File

@@ -5,89 +5,6 @@ require 'pg'
require 'toml'
$stdout.sync = true
$stderr.sync = true
class ConfigEditor
def initialize
@original_config_text = File.read('../../.circleci/pgcat.toml')
text_to_load = @original_config_text.gsub("5432", "\"5432\"")
@original_configs = TOML.load(text_to_load)
end
def original_configs
TOML.load(TOML::Generator.new(@original_configs).body)
end
def with_modified_configs(new_configs)
text_to_write = TOML::Generator.new(new_configs).body
text_to_write = text_to_write.gsub("\"5432\"", "5432")
File.write('../../.circleci/pgcat.toml', text_to_write)
yield
ensure
File.write('../../.circleci/pgcat.toml', @original_config_text)
end
end
def with_captured_stdout_stderr
sout = STDOUT.clone
serr = STDERR.clone
STDOUT.reopen("/tmp/out.txt", "w+")
STDERR.reopen("/tmp/err.txt", "w+")
STDOUT.sync = true
STDERR.sync = true
yield
return File.read('/tmp/out.txt'), File.read('/tmp/err.txt')
ensure
STDOUT.reopen(sout)
STDERR.reopen(serr)
end
def test_extended_protocol_pooler_errors
admin_conn = PG::connect("postgres://admin_user:admin_pass@127.0.0.1:6432/pgcat")
conf_editor = ConfigEditor.new
new_configs = conf_editor.original_configs
# shorter timeouts
new_configs["general"]["connect_timeout"] = 500
new_configs["general"]["ban_time"] = 1
new_configs["general"]["shutdown_timeout"] = 1
new_configs["pools"]["sharded_db"]["users"]["0"]["pool_size"] = 1
new_configs["pools"]["sharded_db"]["users"]["1"]["pool_size"] = 1
conf_editor.with_modified_configs(new_configs) { admin_conn.async_exec("RELOAD") }
conn_str = "postgres://sharding_user:sharding_user@127.0.0.1:6432/sharded_db"
10.times do
Thread.new do
conn = PG::connect(conn_str)
conn.async_exec("SELECT pg_sleep(5)") rescue PG::SystemError
ensure
conn&.close
end
end
sleep(0.5)
conn_under_test = PG::connect(conn_str)
stdout, stderr = with_captured_stdout_stderr do
5.times do |i|
conn_under_test.async_exec("SELECT 1") rescue PG::SystemError
conn_under_test.exec_params("SELECT #{i} + $1", [i]) rescue PG::SystemError
sleep 1
end
end
raise StandardError, "Libpq got unexpected messages while idle" if stderr.include?("arrived from server while idle")
puts "Pool checkout errors not breaking clients passed"
ensure
sleep 1
admin_conn.async_exec("RELOAD") # Reset state
conn_under_test&.close
end
test_extended_protocol_pooler_errors
# Uncomment these two to see all queries.
# ActiveRecord.verbose_query_logs = true
@@ -227,6 +144,30 @@ def test_server_parameters
end
class ConfigEditor
def initialize
@original_config_text = File.read('../../.circleci/pgcat.toml')
text_to_load = @original_config_text.gsub("5432", "\"5432\"")
@original_configs = TOML.load(text_to_load)
end
def original_configs
TOML.load(TOML::Generator.new(@original_configs).body)
end
def with_modified_configs(new_configs)
text_to_write = TOML::Generator.new(new_configs).body
text_to_write = text_to_write.gsub("\"5432\"", "5432")
File.write('../../.circleci/pgcat.toml', text_to_write)
yield
ensure
File.write('../../.circleci/pgcat.toml', @original_config_text)
end
end
def test_reload_pool_recycling
admin_conn = PG::connect("postgres://admin_user:admin_pass@127.0.0.1:6432/pgcat")
server_conn = PG::connect("postgres://sharding_user:sharding_user@127.0.0.1:6432/sharded_db?application_name=testing_pgcat")
@@ -260,6 +201,3 @@ ensure
end
test_reload_pool_recycling