Implementing graceful shutdown (#105)

* Initial commit for graceful shutdown

* fmt

* Add .vscode to gitignore

* Updates shutdown logic to use channels

* fmt

* fmt

* Adds shutdown timeout

* Fmt and updates tomls

* Updates readme

* fmt and updates log levels

* Update python tests to test shutdown

* merge changes

* Rename listener rx and update bash to be in line with master

* Update python test bash script ordering

* Adds error response message before shutdown

* Add details on shutdown event loop

* Fixes response length for error

* Adds handler for sigterm

* Uses ready for query function and fixes number of bytes

* fmt
This commit is contained in:
zainkabani
2022-08-08 19:01:24 -04:00
committed by GitHub
parent 106ebee71c
commit 3719c22322
12 changed files with 308 additions and 47 deletions

View File

@@ -17,6 +17,9 @@ connect_timeout = 100
# How much time to give the health check query to return with a result (ms). # How much time to give the health check query to return with a result (ms).
healthcheck_timeout = 100 healthcheck_timeout = 100
# How much time to give clients during shutdown before forcibly killing client connections (ms).
shutdown_timeout = 5000
# For how long to ban a server if it fails a health check (seconds). # For how long to ban a server if it fails a health check (seconds).
ban_time = 60 # Seconds ban_time = 60 # Seconds

View File

@@ -74,12 +74,12 @@ cd ../..
# #
# Python tests # Python tests
# These tests will start and stop the pgcat server so it will need to be restarted after the tests
# #
cd tests/python pip3 install -r tests/python/requirements.txt
pip3 install -r requirements.txt python3 tests/python/tests.py
python3 tests.py
cd ../..
start_pgcat "info"
# Admin tests # Admin tests
export PGPASSWORD=admin_pass export PGPASSWORD=admin_pass

1
.gitignore vendored
View File

@@ -1,3 +1,4 @@
.idea .idea
/target /target
*.deb *.deb
.vscode

View File

@@ -47,6 +47,7 @@ psql -h 127.0.0.1 -p 6432 -c 'SELECT 1'
| `pool_mode` | The pool mode to use, i.e. `session` or `transaction`. | `transaction` | | `pool_mode` | The pool mode to use, i.e. `session` or `transaction`. | `transaction` |
| `connect_timeout` | Maximum time to establish a connection to a server (milliseconds). If reached, the server is banned and the next target is attempted. | `5000` | | `connect_timeout` | Maximum time to establish a connection to a server (milliseconds). If reached, the server is banned and the next target is attempted. | `5000` |
| `healthcheck_timeout` | Maximum time to pass a health check (`SELECT 1`, milliseconds). If reached, the server is banned and the next target is attempted. | `1000` | | `healthcheck_timeout` | Maximum time to pass a health check (`SELECT 1`, milliseconds). If reached, the server is banned and the next target is attempted. | `1000` |
| `shutdown_timeout` | Maximum time to give clients during shutdown before forcibly killing client connections (ms). | `60000` |
| `ban_time` | Ban time for a server (seconds). It won't be allowed to serve transactions until the ban expires; failover targets will be used instead. | `60` | | `ban_time` | Ban time for a server (seconds). It won't be allowed to serve transactions until the ban expires; failover targets will be used instead. | `60` |
| | | | | | | |
| **`user`** | | | | **`user`** | | |
@@ -250,6 +251,7 @@ The config can be reloaded by sending a `kill -s SIGHUP` to the process or by qu
| `pool_mode` | no | | `pool_mode` | no |
| `connect_timeout` | yes | | `connect_timeout` | yes |
| `healthcheck_timeout` | no | | `healthcheck_timeout` | no |
| `shutdown_timeout` | no |
| `ban_time` | no | | `ban_time` | no |
| `user` | yes | | `user` | yes |
| `shards` | yes | | `shards` | yes |

View File

@@ -17,6 +17,9 @@ connect_timeout = 5000
# How much time to give `SELECT 1` health check query to return with a result (ms). # How much time to give `SELECT 1` health check query to return with a result (ms).
healthcheck_timeout = 1000 healthcheck_timeout = 1000
# How much time to give clients during shutdown before forcibly killing client connections (ms).
shutdown_timeout = 60000
# For how long to ban a server if it fails a health check (seconds). # For how long to ban a server if it fails a health check (seconds).
ban_time = 60 # seconds ban_time = 60 # seconds

View File

@@ -17,6 +17,9 @@ connect_timeout = 5000
# How much time to give the health check query to return with a result (ms). # How much time to give the health check query to return with a result (ms).
healthcheck_timeout = 1000 healthcheck_timeout = 1000
# How much time to give clients during shutdown before forcibly killing client connections (ms).
shutdown_timeout = 60000
# For how long to ban a server if it fails a health check (seconds). # For how long to ban a server if it fails a health check (seconds).
ban_time = 60 # seconds ban_time = 60 # seconds

View File

@@ -4,6 +4,7 @@ use log::{debug, error, info, trace};
use std::collections::HashMap; use std::collections::HashMap;
use tokio::io::{split, AsyncReadExt, BufReader, ReadHalf, WriteHalf}; use tokio::io::{split, AsyncReadExt, BufReader, ReadHalf, WriteHalf};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::sync::broadcast::Receiver;
use crate::admin::{generate_server_info_for_admin, handle_admin}; use crate::admin::{generate_server_info_for_admin, handle_admin};
use crate::config::get_config; use crate::config::get_config;
@@ -73,12 +74,15 @@ pub struct Client<S, T> {
last_server_id: Option<i32>, last_server_id: Option<i32>,
target_pool: ConnectionPool, target_pool: ConnectionPool,
shutdown_event_receiver: Receiver<()>,
} }
/// Client entrypoint. /// Client entrypoint.
pub async fn client_entrypoint( pub async fn client_entrypoint(
mut stream: TcpStream, mut stream: TcpStream,
client_server_map: ClientServerMap, client_server_map: ClientServerMap,
shutdown_event_receiver: Receiver<()>,
) -> Result<(), Error> { ) -> Result<(), Error> {
// Figure out if the client wants TLS or not. // Figure out if the client wants TLS or not.
let addr = stream.peer_addr().unwrap(); let addr = stream.peer_addr().unwrap();
@@ -97,7 +101,7 @@ pub async fn client_entrypoint(
write_all(&mut stream, yes).await?; write_all(&mut stream, yes).await?;
// Negotiate TLS. // Negotiate TLS.
match startup_tls(stream, client_server_map).await { match startup_tls(stream, client_server_map, shutdown_event_receiver).await {
Ok(mut client) => { Ok(mut client) => {
info!("Client {:?} connected (TLS)", addr); info!("Client {:?} connected (TLS)", addr);
@@ -121,7 +125,16 @@ pub async fn client_entrypoint(
let (read, write) = split(stream); let (read, write) = split(stream);
// Continue with regular startup. // Continue with regular startup.
match Client::startup(read, write, addr, bytes, client_server_map).await { match Client::startup(
read,
write,
addr,
bytes,
client_server_map,
shutdown_event_receiver,
)
.await
{
Ok(mut client) => { Ok(mut client) => {
info!("Client {:?} connected (plain)", addr); info!("Client {:?} connected (plain)", addr);
@@ -142,7 +155,16 @@ pub async fn client_entrypoint(
let (read, write) = split(stream); let (read, write) = split(stream);
// Continue with regular startup. // Continue with regular startup.
match Client::startup(read, write, addr, bytes, client_server_map).await { match Client::startup(
read,
write,
addr,
bytes,
client_server_map,
shutdown_event_receiver,
)
.await
{
Ok(mut client) => { Ok(mut client) => {
info!("Client {:?} connected (plain)", addr); info!("Client {:?} connected (plain)", addr);
@@ -157,7 +179,16 @@ pub async fn client_entrypoint(
let (read, write) = split(stream); let (read, write) = split(stream);
// Continue with cancel query request. // Continue with cancel query request.
match Client::cancel(read, write, addr, bytes, client_server_map).await { match Client::cancel(
read,
write,
addr,
bytes,
client_server_map,
shutdown_event_receiver,
)
.await
{
Ok(mut client) => { Ok(mut client) => {
info!("Client {:?} issued a cancel query request", addr); info!("Client {:?} issued a cancel query request", addr);
@@ -214,6 +245,7 @@ where
pub async fn startup_tls( pub async fn startup_tls(
stream: TcpStream, stream: TcpStream,
client_server_map: ClientServerMap, client_server_map: ClientServerMap,
shutdown_event_receiver: Receiver<()>,
) -> Result<Client<ReadHalf<TlsStream<TcpStream>>, WriteHalf<TlsStream<TcpStream>>>, Error> { ) -> Result<Client<ReadHalf<TlsStream<TcpStream>>, WriteHalf<TlsStream<TcpStream>>>, Error> {
// Negotiate TLS. // Negotiate TLS.
let tls = Tls::new()?; let tls = Tls::new()?;
@@ -237,7 +269,15 @@ pub async fn startup_tls(
Ok((ClientConnectionType::Startup, bytes)) => { Ok((ClientConnectionType::Startup, bytes)) => {
let (read, write) = split(stream); let (read, write) = split(stream);
Client::startup(read, write, addr, bytes, client_server_map).await Client::startup(
read,
write,
addr,
bytes,
client_server_map,
shutdown_event_receiver,
)
.await
} }
// Bad Postgres client. // Bad Postgres client.
@@ -258,6 +298,7 @@ where
addr: std::net::SocketAddr, addr: std::net::SocketAddr,
bytes: BytesMut, // The rest of the startup message. bytes: BytesMut, // The rest of the startup message.
client_server_map: ClientServerMap, client_server_map: ClientServerMap,
shutdown_event_receiver: Receiver<()>,
) -> Result<Client<S, T>, Error> { ) -> Result<Client<S, T>, Error> {
let config = get_config(); let config = get_config();
let stats = get_reporter(); let stats = get_reporter();
@@ -384,6 +425,7 @@ where
last_address_id: None, last_address_id: None,
last_server_id: None, last_server_id: None,
target_pool: target_pool, target_pool: target_pool,
shutdown_event_receiver: shutdown_event_receiver,
}); });
} }
@@ -394,6 +436,7 @@ where
addr: std::net::SocketAddr, addr: std::net::SocketAddr,
mut bytes: BytesMut, // The rest of the startup message. mut bytes: BytesMut, // The rest of the startup message.
client_server_map: ClientServerMap, client_server_map: ClientServerMap,
shutdown_event_receiver: Receiver<()>,
) -> Result<Client<S, T>, Error> { ) -> Result<Client<S, T>, Error> {
let process_id = bytes.get_i32(); let process_id = bytes.get_i32();
let secret_key = bytes.get_i32(); let secret_key = bytes.get_i32();
@@ -413,6 +456,7 @@ where
last_address_id: None, last_address_id: None,
last_server_id: None, last_server_id: None,
target_pool: ConnectionPool::default(), target_pool: ConnectionPool::default(),
shutdown_event_receiver: shutdown_event_receiver,
}); });
} }
@@ -467,7 +511,14 @@ where
// We can parse it here before grabbing a server from the pool, // We can parse it here before grabbing a server from the pool,
// in case the client is sending some custom protocol messages, e.g. // in case the client is sending some custom protocol messages, e.g.
// SET SHARDING KEY TO 'bigint'; // SET SHARDING KEY TO 'bigint';
let mut message = read_message(&mut self.read).await?;
let mut message = tokio::select! {
_ = self.shutdown_event_receiver.recv() => {
error_response_terminal(&mut self.write, &format!("terminating connection due to administrator command")).await?;
return Ok(())
},
message_result = read_message(&mut self.read) => message_result?
};
// Get a pool instance referenced by the most up-to-date // Get a pool instance referenced by the most up-to-date
// pointer. This ensures we always read the latest config // pointer. This ensures we always read the latest config

View File

@@ -119,6 +119,7 @@ pub struct General {
pub port: i16, pub port: i16,
pub connect_timeout: u64, pub connect_timeout: u64,
pub healthcheck_timeout: u64, pub healthcheck_timeout: u64,
pub shutdown_timeout: u64,
pub ban_time: i64, pub ban_time: i64,
pub autoreload: bool, pub autoreload: bool,
pub tls_certificate: Option<String>, pub tls_certificate: Option<String>,
@@ -134,6 +135,7 @@ impl Default for General {
port: 5432, port: 5432,
connect_timeout: 5000, connect_timeout: 5000,
healthcheck_timeout: 1000, healthcheck_timeout: 1000,
shutdown_timeout: 60000,
ban_time: 60, ban_time: 60,
autoreload: false, autoreload: false,
tls_certificate: None, tls_certificate: None,
@@ -273,6 +275,10 @@ impl From<&Config> for std::collections::HashMap<String, String> {
"healthcheck_timeout".to_string(), "healthcheck_timeout".to_string(),
config.general.healthcheck_timeout.to_string(), config.general.healthcheck_timeout.to_string(),
), ),
(
"shutdown_timeout".to_string(),
config.general.shutdown_timeout.to_string(),
),
("ban_time".to_string(), config.general.ban_time.to_string()), ("ban_time".to_string(), config.general.ban_time.to_string()),
]; ];
@@ -290,6 +296,7 @@ impl Config {
self.general.healthcheck_timeout self.general.healthcheck_timeout
); );
info!("Connection timeout: {}ms", self.general.connect_timeout); info!("Connection timeout: {}ms", self.general.connect_timeout);
info!("Shutdown timeout: {}ms", self.general.shutdown_timeout);
match self.general.tls_certificate.clone() { match self.general.tls_certificate.clone() {
Some(tls_certificate) => { Some(tls_certificate) => {
info!("TLS certificate: {}", tls_certificate); info!("TLS certificate: {}", tls_certificate);

View File

@@ -40,13 +40,13 @@ use log::{debug, error, info};
use parking_lot::Mutex; use parking_lot::Mutex;
use tokio::net::TcpListener; use tokio::net::TcpListener;
use tokio::{ use tokio::{
signal,
signal::unix::{signal as unix_signal, SignalKind}, signal::unix::{signal as unix_signal, SignalKind},
sync::mpsc, sync::mpsc,
}; };
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::broadcast;
mod admin; mod admin;
mod client; mod client;
@@ -139,24 +139,52 @@ async fn main() {
info!("Waiting for clients"); info!("Waiting for clients");
let (shutdown_event_tx, mut shutdown_event_rx) = broadcast::channel::<()>(1);
let shutdown_event_tx_clone = shutdown_event_tx.clone();
// Client connection loop. // Client connection loop.
tokio::task::spawn(async move { tokio::task::spawn(async move {
// Create a subscriber for the shutdown event; it is dropped when the shutdown event is broadcast and this loop exits
let mut listener_shutdown_event_rx = shutdown_event_tx_clone.subscribe();
loop { loop {
let client_server_map = client_server_map.clone(); let client_server_map = client_server_map.clone();
let (socket, addr) = match listener.accept().await { // Listen for shutdown event and client connection at the same time
Ok((socket, addr)) => (socket, addr), let (socket, addr) = tokio::select! {
Err(err) => { _ = listener_shutdown_event_rx.recv() => {
error!("{:?}", err); // Exits client connection loop which drops listener, listener_shutdown_event_rx and shutdown_event_tx_clone
continue; break;
}
listener_response = listener.accept() => {
match listener_response {
Ok((socket, addr)) => (socket, addr),
Err(err) => {
error!("{:?}", err);
continue;
}
}
} }
}; };
// Used to signal shutdown
let client_shutdown_handler_rx = shutdown_event_tx_clone.subscribe();
// Used to signal that the task has completed
let dummy_tx = shutdown_event_tx_clone.clone();
// Handle client. // Handle client.
tokio::task::spawn(async move { tokio::task::spawn(async move {
let start = chrono::offset::Utc::now().naive_utc(); let start = chrono::offset::Utc::now().naive_utc();
match client::client_entrypoint(socket, client_server_map).await { match client::client_entrypoint(
socket,
client_server_map,
client_shutdown_handler_rx,
)
.await
{
Ok(_) => { Ok(_) => {
let duration = chrono::offset::Utc::now().naive_utc() - start; let duration = chrono::offset::Utc::now().naive_utc() - start;
@@ -171,6 +199,8 @@ async fn main() {
debug!("Client disconnected with error {:?}", err); debug!("Client disconnected with error {:?}", err);
} }
}; };
// Drop this transmitter so receiver knows that the task is completed
drop(dummy_tx);
}); });
} }
}); });
@@ -214,13 +244,41 @@ async fn main() {
}); });
} }
// Exit on Ctrl-C (SIGINT) and SIGTERM.
let mut term_signal = unix_signal(SignalKind::terminate()).unwrap(); let mut term_signal = unix_signal(SignalKind::terminate()).unwrap();
let mut interrupt_signal = unix_signal(SignalKind::interrupt()).unwrap();
tokio::select! { tokio::select! {
_ = signal::ctrl_c() => (), // Initiate graceful shutdown sequence on sig int
_ = interrupt_signal.recv() => {
info!("Got SIGINT, waiting for client connection drain now");
// Broadcast that client tasks need to finish
shutdown_event_tx.send(()).unwrap();
// Closes transmitter
drop(shutdown_event_tx);
// This is in a loop because the first value the receiver yields is the shutdown event that was just broadcast.
// That event is not what we are waiting for; instead, we wait for `recv()` to return an error, which happens once all senders are closed, i.e. after the client tasks have seen the shutdown event and dropped their senders.
loop {
match tokio::time::timeout(
tokio::time::Duration::from_millis(config.general.shutdown_timeout),
shutdown_event_rx.recv(),
)
.await
{
Ok(res) => match res {
Ok(_) => {}
Err(_) => break,
},
Err(_) => {
info!("Timed out while waiting for clients to shutdown");
break;
}
}
}
},
_ = term_signal.recv() => (), _ = term_signal.recv() => (),
}; }
info!("Shutting down..."); info!("Shutting down...");
} }

View File

@@ -98,7 +98,9 @@ pub async fn ready_for_query<S>(stream: &mut S) -> Result<(), Error>
where where
S: tokio::io::AsyncWrite + std::marker::Unpin, S: tokio::io::AsyncWrite + std::marker::Unpin,
{ {
let mut bytes = BytesMut::with_capacity(5); let mut bytes = BytesMut::with_capacity(
mem::size_of::<u8>() + mem::size_of::<i32>() + mem::size_of::<u8>(),
);
bytes.put_u8(b'Z'); bytes.put_u8(b'Z');
bytes.put_i32(5); bytes.put_i32(5);
@@ -252,18 +254,25 @@ where
res.put_i32(len); res.put_i32(len);
res.put_slice(&set_complete[..]); res.put_slice(&set_complete[..]);
// ReadyForQuery (idle) write_all_half(stream, res).await?;
res.put_u8(b'Z'); ready_for_query(stream).await
res.put_i32(5);
res.put_u8(b'I');
write_all_half(stream, res).await
} }
/// Send a custom error message to the client. /// Send a custom error message to the client.
/// Tell the client we are ready for the next query and no rollback is necessary. /// Tell the client we are ready for the next query and no rollback is necessary.
/// Docs on error codes: <https://www.postgresql.org/docs/12/errcodes-appendix.html>. /// Docs on error codes: <https://www.postgresql.org/docs/12/errcodes-appendix.html>.
pub async fn error_response<S>(stream: &mut S, message: &str) -> Result<(), Error> pub async fn error_response<S>(stream: &mut S, message: &str) -> Result<(), Error>
where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
error_response_terminal(stream, message).await?;
ready_for_query(stream).await
}
/// Send a custom error message to the client.
/// Unlike `error_response`, this does NOT follow the error with a ReadyForQuery message; use it when the connection is about to be closed.
/// Docs on error codes: <https://www.postgresql.org/docs/12/errcodes-appendix.html>.
pub async fn error_response_terminal<S>(stream: &mut S, message: &str) -> Result<(), Error>
where where
S: tokio::io::AsyncWrite + std::marker::Unpin, S: tokio::io::AsyncWrite + std::marker::Unpin,
{ {
@@ -288,21 +297,12 @@ where
// No more fields follow. // No more fields follow.
error.put_u8(0); error.put_u8(0);
// Ready for query, no rollback needed (I = idle).
let mut ready_for_query = BytesMut::new();
ready_for_query.put_u8(b'Z');
ready_for_query.put_i32(5);
ready_for_query.put_u8(b'I');
// Compose the two message reply. // Compose the two message reply.
let mut res = BytesMut::with_capacity(error.len() + ready_for_query.len() + 5); let mut res = BytesMut::with_capacity(error.len() + 5);
res.put_u8(b'E'); res.put_u8(b'E');
res.put_i32(error.len() as i32 + 4); res.put_i32(error.len() as i32 + 4);
res.put(error); res.put(error);
res.put(ready_for_query);
Ok(write_all_half(stream, res).await?) Ok(write_all_half(stream, res).await?)
} }
@@ -366,12 +366,8 @@ where
// CommandComplete // CommandComplete
res.put(command_complete("SELECT 1")); res.put(command_complete("SELECT 1"));
// ReadyForQuery write_all_half(stream, res).await?;
res.put_u8(b'Z'); ready_for_query(stream).await
res.put_i32(5);
res.put_u8(b'I');
write_all_half(stream, res).await
} }
pub fn row_description(columns: &Vec<(&str, DataType)>) -> BytesMut { pub fn row_description(columns: &Vec<(&str, DataType)>) -> BytesMut {

View File

@@ -1 +1,2 @@
psycopg2==2.9.3 psycopg2==2.9.3
psutil==5.9.1

View File

@@ -1,22 +1,158 @@
from typing import Tuple
import psycopg2 import psycopg2
import psutil
import os
import signal
import subprocess
from threading import Thread
import time
def test_normal_db_access(): SHUTDOWN_TIMEOUT = 5
conn = psycopg2.connect("postgres://sharding_user:sharding_user@127.0.0.1:6432/sharded_db?application_name=testing_pgcat")
PGCAT_HOST = "127.0.0.1"
PGCAT_PORT = "6432"
def pgcat_start():
    """Signal any running pgcat to terminate, then launch a fresh instance.

    The new process is the debug build started with the CircleCI config;
    it is spawned in the background and not waited on.
    """
    # Ask any already-running pgcat process to terminate first.
    pg_cat_send_signal(signal.SIGTERM)
    argv = "./target/debug/pgcat .circleci/pgcat.toml".split()
    subprocess.Popen(argv)
def pg_cat_send_signal(signal: signal.Signals):
    """Send ``signal`` to every running process whose name is exactly "pgcat".

    Args:
        signal: The signal to deliver (e.g. ``signal.SIGINT``). Note that the
            parameter shadows the ``signal`` module inside this function.

    Processes that disappear between being listed and being signalled are
    skipped silently, since there is nothing left to signal.
    """
    for proc in psutil.process_iter(["pid", "name"]):
        # Read the name pre-fetched by process_iter instead of calling
        # proc.name(), which re-queries the process and can raise if it has
        # already exited.
        if proc.info["name"] != "pgcat":
            continue
        try:
            os.kill(proc.pid, signal)
        except ProcessLookupError:
            # The process exited after it was listed; nothing to do.
            pass
def connect_normal_db(
autocommit: bool = False,
) -> Tuple[psycopg2.extensions.connection, psycopg2.extensions.cursor]:
conn = psycopg2.connect(
f"postgres://sharding_user:sharding_user@{PGCAT_HOST}:{PGCAT_PORT}/sharded_db?application_name=testing_pgcat"
)
conn.autocommit = autocommit
cur = conn.cursor() cur = conn.cursor()
return (conn, cur)
def cleanup_conn(conn: psycopg2.extensions.connection, cur: psycopg2.extensions.cursor):
    """Close the cursor, then the connection."""
    for handle in (cur, conn):
        handle.close()
def test_normal_db_access():
conn, cur = connect_normal_db()
cur.execute("SELECT 1") cur.execute("SELECT 1")
res = cur.fetchall() res = cur.fetchall()
print(res) print(res)
cleanup_conn(conn, cur)
def test_admin_db_access(): def test_admin_db_access():
conn = psycopg2.connect("postgres://admin_user:admin_pass@127.0.0.1:6432/pgcat") conn = psycopg2.connect(
conn.autocommit = True # BEGIN/COMMIT is not supported by admin db f"postgres://admin_user:admin_pass@{PGCAT_HOST}:{PGCAT_PORT}/pgcat"
)
conn.autocommit = True # BEGIN/COMMIT is not supported by admin db
cur = conn.cursor() cur = conn.cursor()
cur.execute("SHOW POOLS") cur.execute("SHOW POOLS")
res = cur.fetchall() res = cur.fetchall()
print(res) print(res)
cleanup_conn(conn, cur)
def _start_pgcat_and_wait(startup_seconds):
    """Launch pgcat via ``pgcat_start`` on a helper thread, then sleep ``startup_seconds``."""
    server = Thread(target=pgcat_start)
    server.start()

    # Wait for server to fully start up
    time.sleep(startup_seconds)


def _expect_query_failure(cur, failure_message):
    """Run ``SELECT 1;`` and raise ``Exception(failure_message)`` unless it fails with OperationalError."""
    try:
        cur.execute("SELECT 1;")
    except psycopg2.OperationalError:
        return
    # Fail if query execution succeeded
    raise Exception(failure_message)


def test_shutdown_logic():
    """Exercise pgcat's SIGINT shutdown handling in three scenarios.

    1. Client not in a transaction: new queries must fail after SIGINT.
    2. Client inside a transaction: queries must still succeed after SIGINT.
    3. Client inside a transaction past the shutdown timeout: queries must fail.

    Each scenario (re)starts pgcat; raises ``Exception`` when an expectation
    is not met.
    """

    ##### NO ACTIVE QUERIES SIGINT HANDLING #####
    # Start pgcat
    _start_pgcat_and_wait(2)

    # Create client connection and send query (not in transaction)
    conn, cur = connect_normal_db(True)

    cur.execute("BEGIN;")
    cur.execute("SELECT 1;")
    cur.execute("COMMIT;")

    # Send sigint to pgcat
    pg_cat_send_signal(signal.SIGINT)
    time.sleep(1)

    # Check that any new queries fail after sigint since server should close with no active transactions
    _expect_query_failure(cur, "Server not closed after sigint")

    cleanup_conn(conn, cur)

    ##### HANDLE TRANSACTION WITH SIGINT #####
    # Start pgcat
    _start_pgcat_and_wait(2)

    # Create client connection and begin transaction
    conn, cur = connect_normal_db(True)

    cur.execute("BEGIN;")
    cur.execute("SELECT 1;")

    # Send sigint to pgcat while still in transaction
    pg_cat_send_signal(signal.SIGINT)
    time.sleep(1)

    # Check that any new queries succeed after sigint since server should still allow transaction to complete
    try:
        cur.execute("SELECT 1;")
    except psycopg2.OperationalError as e:
        # Fail if query fails since server closed
        raise Exception("Server closed while in transaction", e.pgerror)

    cleanup_conn(conn, cur)

    ##### HANDLE SHUTDOWN TIMEOUT WITH SIGINT #####
    # Start pgcat
    _start_pgcat_and_wait(3)

    # Create client connection and begin transaction, which should prevent server shutdown unless shutdown timeout is reached
    conn, cur = connect_normal_db(True)

    cur.execute("BEGIN;")
    cur.execute("SELECT 1;")

    # Send sigint to pgcat while still in transaction
    pg_cat_send_signal(signal.SIGINT)

    # pgcat shutdown timeout is set to SHUTDOWN_TIMEOUT seconds, so we sleep for SHUTDOWN_TIMEOUT + 1 seconds
    time.sleep(SHUTDOWN_TIMEOUT + 1)

    # Check that any new queries fail after sigint, since the shutdown timeout has elapsed and the server is expected to have closed
    _expect_query_failure(cur, "Server not closed after sigint and expected timeout")

    cleanup_conn(conn, cur)
test_normal_db_access() test_normal_db_access()
test_admin_db_access() test_admin_db_access()
test_shutdown_logic()