mirror of
https://github.com/postgresml/pgcat.git
synced 2026-03-26 02:16:30 +00:00
Compare commits
2 Commits
mostafa_fi
...
circleci_A
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bc28d68ec2 | ||
|
|
494e8126e1 |
@@ -106,7 +106,7 @@ cd ../..
|
|||||||
# These tests will start and stop the pgcat server so it will need to be restarted after the tests
|
# These tests will start and stop the pgcat server so it will need to be restarted after the tests
|
||||||
#
|
#
|
||||||
pip3 install -r tests/python/requirements.txt
|
pip3 install -r tests/python/requirements.txt
|
||||||
python3 tests/python/tests.py || exit 1
|
pytest || exit 1
|
||||||
|
|
||||||
|
|
||||||
#
|
#
|
||||||
|
|||||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -10,4 +10,5 @@ lcov.info
|
|||||||
dev/.bash_history
|
dev/.bash_history
|
||||||
dev/cache
|
dev/cache
|
||||||
!dev/cache/.keepme
|
!dev/cache/.keepme
|
||||||
.venv
|
.venv
|
||||||
|
**/__pycache__
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ Thank you for contributing! Just a few tips here:
|
|||||||
3. Performance is important, make sure there are no regressions in your branch vs. `main`.
|
3. Performance is important, make sure there are no regressions in your branch vs. `main`.
|
||||||
|
|
||||||
## How to run the integration tests locally and iterate on them
|
## How to run the integration tests locally and iterate on them
|
||||||
We have integration tests written in Ruby, Python, Go and Rust.
|
We have integration tests written in Ruby, Python, Go and Rust.
|
||||||
Below are the steps to run them in a developer-friendly way that allows iterating and quick turnaround.
|
Below are the steps to run them in a developer-friendly way that allows iterating and quick turnaround.
|
||||||
Hear me out, this should be easy, it will involve opening a shell into a container with all the necessary dependancies available for you and you can modify the test code and immediately rerun your test in the interactive shell.
|
Hear me out, this should be easy, it will involve opening a shell into a container with all the necessary dependancies available for you and you can modify the test code and immediately rerun your test in the interactive shell.
|
||||||
|
|
||||||
@@ -21,7 +21,7 @@ Within this test environment you can modify the file in your favorite IDE and re
|
|||||||
|
|
||||||
Once the environment is ready, you can run the tests by running
|
Once the environment is ready, you can run the tests by running
|
||||||
Ruby: `cd /app/tests/ruby && bundle exec ruby <test_name>.rb --format documentation`
|
Ruby: `cd /app/tests/ruby && bundle exec ruby <test_name>.rb --format documentation`
|
||||||
Python: `cd /app && python3 tests/python/tests.py`
|
Python: `cd /app/ && pytest`
|
||||||
Rust: `cd /app/tests/rust && cargo run`
|
Rust: `cd /app/tests/rust && cargo run`
|
||||||
Go: `cd /app/tests/go && /usr/local/go/bin/go test`
|
Go: `cd /app/tests/go && /usr/local/go/bin/go test`
|
||||||
|
|
||||||
|
|||||||
@@ -1729,13 +1729,14 @@ where
|
|||||||
/// and also the pool's statement cache. Add it to extended protocol data.
|
/// and also the pool's statement cache. Add it to extended protocol data.
|
||||||
fn buffer_parse(&mut self, message: BytesMut, pool: &ConnectionPool) -> Result<(), Error> {
|
fn buffer_parse(&mut self, message: BytesMut, pool: &ConnectionPool) -> Result<(), Error> {
|
||||||
// Avoid parsing if prepared statements not enabled
|
// Avoid parsing if prepared statements not enabled
|
||||||
let client_given_name = Parse::get_name(&message)?;
|
if !self.prepared_statements_enabled {
|
||||||
if !self.prepared_statements_enabled || client_given_name.is_empty() {
|
|
||||||
debug!("Anonymous parse message");
|
debug!("Anonymous parse message");
|
||||||
self.extended_protocol_data_buffer
|
self.extended_protocol_data_buffer
|
||||||
.push_back(ExtendedProtocolData::create_new_parse(message, None));
|
.push_back(ExtendedProtocolData::create_new_parse(message, None));
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let client_given_name = Parse::get_name(&message)?;
|
||||||
let parse: Parse = (&message).try_into()?;
|
let parse: Parse = (&message).try_into()?;
|
||||||
|
|
||||||
// Compute the hash of the parse statement
|
// Compute the hash of the parse statement
|
||||||
@@ -1773,14 +1774,15 @@ where
|
|||||||
/// saved in the client cache.
|
/// saved in the client cache.
|
||||||
async fn buffer_bind(&mut self, message: BytesMut) -> Result<(), Error> {
|
async fn buffer_bind(&mut self, message: BytesMut) -> Result<(), Error> {
|
||||||
// Avoid parsing if prepared statements not enabled
|
// Avoid parsing if prepared statements not enabled
|
||||||
let client_given_name = Bind::get_name(&message)?;
|
if !self.prepared_statements_enabled {
|
||||||
if !self.prepared_statements_enabled || client_given_name.is_empty() {
|
|
||||||
debug!("Anonymous bind message");
|
debug!("Anonymous bind message");
|
||||||
self.extended_protocol_data_buffer
|
self.extended_protocol_data_buffer
|
||||||
.push_back(ExtendedProtocolData::create_new_bind(message, None));
|
.push_back(ExtendedProtocolData::create_new_bind(message, None));
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let client_given_name = Bind::get_name(&message)?;
|
||||||
|
|
||||||
match self.prepared_statements.get(&client_given_name) {
|
match self.prepared_statements.get(&client_given_name) {
|
||||||
Some((rewritten_parse, _)) => {
|
Some((rewritten_parse, _)) => {
|
||||||
let message = Bind::rename(message, &rewritten_parse.name)?;
|
let message = Bind::rename(message, &rewritten_parse.name)?;
|
||||||
@@ -1832,8 +1834,7 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
let describe: Describe = (&message).try_into()?;
|
let describe: Describe = (&message).try_into()?;
|
||||||
let client_given_name = describe.statement_name.clone();
|
if describe.target == 'P' {
|
||||||
if describe.target == 'P' || client_given_name.is_empty() {
|
|
||||||
debug!("Portal describe message");
|
debug!("Portal describe message");
|
||||||
self.extended_protocol_data_buffer
|
self.extended_protocol_data_buffer
|
||||||
.push_back(ExtendedProtocolData::create_new_describe(message, None));
|
.push_back(ExtendedProtocolData::create_new_describe(message, None));
|
||||||
@@ -1841,6 +1842,8 @@ where
|
|||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let client_given_name = describe.statement_name.clone();
|
||||||
|
|
||||||
match self.prepared_statements.get(&client_given_name) {
|
match self.prepared_statements.get(&client_given_name) {
|
||||||
Some((rewritten_parse, _)) => {
|
Some((rewritten_parse, _)) => {
|
||||||
let describe = describe.rename(&rewritten_parse.name);
|
let describe = describe.rename(&rewritten_parse.name);
|
||||||
|
|||||||
@@ -821,10 +821,10 @@ impl ExtendedProtocolData {
|
|||||||
pub struct Parse {
|
pub struct Parse {
|
||||||
code: char,
|
code: char,
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
len: u32,
|
len: i32,
|
||||||
pub name: String,
|
pub name: String,
|
||||||
query: String,
|
query: String,
|
||||||
num_params: u16,
|
num_params: i16,
|
||||||
param_types: Vec<i32>,
|
param_types: Vec<i32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -834,11 +834,12 @@ impl TryFrom<&BytesMut> for Parse {
|
|||||||
fn try_from(buf: &BytesMut) -> Result<Parse, Error> {
|
fn try_from(buf: &BytesMut) -> Result<Parse, Error> {
|
||||||
let mut cursor = Cursor::new(buf);
|
let mut cursor = Cursor::new(buf);
|
||||||
let code = cursor.get_u8() as char;
|
let code = cursor.get_u8() as char;
|
||||||
let len = cursor.get_u32();
|
let len = cursor.get_i32();
|
||||||
let name = cursor.read_string()?;
|
let name = cursor.read_string()?;
|
||||||
let query = cursor.read_string()?;
|
let query = cursor.read_string()?;
|
||||||
let num_params = cursor.get_u16();
|
let num_params = cursor.get_i16();
|
||||||
let mut param_types = Vec::new();
|
let mut param_types = Vec::new();
|
||||||
|
|
||||||
for _ in 0..num_params {
|
for _ in 0..num_params {
|
||||||
param_types.push(cursor.get_i32());
|
param_types.push(cursor.get_i32());
|
||||||
}
|
}
|
||||||
@@ -874,10 +875,10 @@ impl TryFrom<Parse> for BytesMut {
|
|||||||
+ 4 * parse.num_params as usize;
|
+ 4 * parse.num_params as usize;
|
||||||
|
|
||||||
bytes.put_u8(parse.code as u8);
|
bytes.put_u8(parse.code as u8);
|
||||||
bytes.put_u32(len as u32);
|
bytes.put_i32(len as i32);
|
||||||
bytes.put_slice(name);
|
bytes.put_slice(name);
|
||||||
bytes.put_slice(query);
|
bytes.put_slice(query);
|
||||||
bytes.put_u16(parse.num_params);
|
bytes.put_i16(parse.num_params);
|
||||||
for param in parse.param_types {
|
for param in parse.param_types {
|
||||||
bytes.put_i32(param);
|
bytes.put_i32(param);
|
||||||
}
|
}
|
||||||
@@ -944,14 +945,14 @@ impl Parse {
|
|||||||
pub struct Bind {
|
pub struct Bind {
|
||||||
code: char,
|
code: char,
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
len: u64,
|
len: i64,
|
||||||
portal: String,
|
portal: String,
|
||||||
pub prepared_statement: String,
|
pub prepared_statement: String,
|
||||||
num_param_format_codes: u16,
|
num_param_format_codes: i16,
|
||||||
param_format_codes: Vec<i16>,
|
param_format_codes: Vec<i16>,
|
||||||
num_param_values: u16,
|
num_param_values: i16,
|
||||||
param_values: Vec<(i32, BytesMut)>,
|
param_values: Vec<(i32, BytesMut)>,
|
||||||
num_result_column_format_codes: u16,
|
num_result_column_format_codes: i16,
|
||||||
result_columns_format_codes: Vec<i16>,
|
result_columns_format_codes: Vec<i16>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -961,17 +962,17 @@ impl TryFrom<&BytesMut> for Bind {
|
|||||||
fn try_from(buf: &BytesMut) -> Result<Bind, Error> {
|
fn try_from(buf: &BytesMut) -> Result<Bind, Error> {
|
||||||
let mut cursor = Cursor::new(buf);
|
let mut cursor = Cursor::new(buf);
|
||||||
let code = cursor.get_u8() as char;
|
let code = cursor.get_u8() as char;
|
||||||
let len = cursor.get_u32();
|
let len = cursor.get_i32();
|
||||||
let portal = cursor.read_string()?;
|
let portal = cursor.read_string()?;
|
||||||
let prepared_statement = cursor.read_string()?;
|
let prepared_statement = cursor.read_string()?;
|
||||||
let num_param_format_codes = cursor.get_u16();
|
let num_param_format_codes = cursor.get_i16();
|
||||||
let mut param_format_codes = Vec::new();
|
let mut param_format_codes = Vec::new();
|
||||||
|
|
||||||
for _ in 0..num_param_format_codes {
|
for _ in 0..num_param_format_codes {
|
||||||
param_format_codes.push(cursor.get_i16());
|
param_format_codes.push(cursor.get_i16());
|
||||||
}
|
}
|
||||||
|
|
||||||
let num_param_values = cursor.get_u16();
|
let num_param_values = cursor.get_i16();
|
||||||
let mut param_values = Vec::new();
|
let mut param_values = Vec::new();
|
||||||
|
|
||||||
for _ in 0..num_param_values {
|
for _ in 0..num_param_values {
|
||||||
@@ -993,7 +994,7 @@ impl TryFrom<&BytesMut> for Bind {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let num_result_column_format_codes = cursor.get_u16();
|
let num_result_column_format_codes = cursor.get_i16();
|
||||||
let mut result_columns_format_codes = Vec::new();
|
let mut result_columns_format_codes = Vec::new();
|
||||||
|
|
||||||
for _ in 0..num_result_column_format_codes {
|
for _ in 0..num_result_column_format_codes {
|
||||||
@@ -1002,7 +1003,7 @@ impl TryFrom<&BytesMut> for Bind {
|
|||||||
|
|
||||||
Ok(Bind {
|
Ok(Bind {
|
||||||
code,
|
code,
|
||||||
len: len as u64,
|
len: len as i64,
|
||||||
portal,
|
portal,
|
||||||
prepared_statement,
|
prepared_statement,
|
||||||
num_param_format_codes,
|
num_param_format_codes,
|
||||||
@@ -1041,19 +1042,19 @@ impl TryFrom<Bind> for BytesMut {
|
|||||||
len += 2 * bind.num_result_column_format_codes as usize;
|
len += 2 * bind.num_result_column_format_codes as usize;
|
||||||
|
|
||||||
bytes.put_u8(bind.code as u8);
|
bytes.put_u8(bind.code as u8);
|
||||||
bytes.put_u32(len as u32);
|
bytes.put_i32(len as i32);
|
||||||
bytes.put_slice(portal);
|
bytes.put_slice(portal);
|
||||||
bytes.put_slice(prepared_statement);
|
bytes.put_slice(prepared_statement);
|
||||||
bytes.put_u16(bind.num_param_format_codes);
|
bytes.put_i16(bind.num_param_format_codes);
|
||||||
for param_format_code in bind.param_format_codes {
|
for param_format_code in bind.param_format_codes {
|
||||||
bytes.put_i16(param_format_code);
|
bytes.put_i16(param_format_code);
|
||||||
}
|
}
|
||||||
bytes.put_u16(bind.num_param_values);
|
bytes.put_i16(bind.num_param_values);
|
||||||
for (param_len, param) in bind.param_values {
|
for (param_len, param) in bind.param_values {
|
||||||
bytes.put_i32(param_len);
|
bytes.put_i32(param_len);
|
||||||
bytes.put_slice(¶m);
|
bytes.put_slice(¶m);
|
||||||
}
|
}
|
||||||
bytes.put_u16(bind.num_result_column_format_codes);
|
bytes.put_i16(bind.num_result_column_format_codes);
|
||||||
for result_column_format_code in bind.result_columns_format_codes {
|
for result_column_format_code in bind.result_columns_format_codes {
|
||||||
bytes.put_i16(result_column_format_code);
|
bytes.put_i16(result_column_format_code);
|
||||||
}
|
}
|
||||||
@@ -1067,7 +1068,7 @@ impl Bind {
|
|||||||
pub fn get_name(buf: &BytesMut) -> Result<String, Error> {
|
pub fn get_name(buf: &BytesMut) -> Result<String, Error> {
|
||||||
let mut cursor = Cursor::new(buf);
|
let mut cursor = Cursor::new(buf);
|
||||||
// Skip the code and length
|
// Skip the code and length
|
||||||
cursor.advance(mem::size_of::<u8>() + mem::size_of::<u32>());
|
cursor.advance(mem::size_of::<u8>() + mem::size_of::<i32>());
|
||||||
cursor.read_string()?;
|
cursor.read_string()?;
|
||||||
cursor.read_string()
|
cursor.read_string()
|
||||||
}
|
}
|
||||||
@@ -1077,17 +1078,17 @@ impl Bind {
|
|||||||
let mut cursor = Cursor::new(&buf);
|
let mut cursor = Cursor::new(&buf);
|
||||||
// Read basic data from the cursor
|
// Read basic data from the cursor
|
||||||
let code = cursor.get_u8();
|
let code = cursor.get_u8();
|
||||||
let current_len = cursor.get_u32();
|
let current_len = cursor.get_i32();
|
||||||
let portal = cursor.read_string()?;
|
let portal = cursor.read_string()?;
|
||||||
let prepared_statement = cursor.read_string()?;
|
let prepared_statement = cursor.read_string()?;
|
||||||
|
|
||||||
// Calculate new length
|
// Calculate new length
|
||||||
let new_len = current_len + new_name.len() as u32 - prepared_statement.len() as u32;
|
let new_len = current_len + new_name.len() as i32 - prepared_statement.len() as i32;
|
||||||
|
|
||||||
// Begin building the response buffer
|
// Begin building the response buffer
|
||||||
let mut response_buf = BytesMut::with_capacity(new_len as usize + 1);
|
let mut response_buf = BytesMut::with_capacity(new_len as usize + 1);
|
||||||
response_buf.put_u8(code);
|
response_buf.put_u8(code);
|
||||||
response_buf.put_u32(new_len);
|
response_buf.put_i32(new_len);
|
||||||
|
|
||||||
// Put the portal and new name into the buffer
|
// Put the portal and new name into the buffer
|
||||||
// Note: panic if the provided string contains null byte
|
// Note: panic if the provided string contains null byte
|
||||||
@@ -1111,7 +1112,7 @@ pub struct Describe {
|
|||||||
code: char,
|
code: char,
|
||||||
|
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
len: u32,
|
len: i32,
|
||||||
pub target: char,
|
pub target: char,
|
||||||
pub statement_name: String,
|
pub statement_name: String,
|
||||||
}
|
}
|
||||||
@@ -1122,7 +1123,7 @@ impl TryFrom<&BytesMut> for Describe {
|
|||||||
fn try_from(bytes: &BytesMut) -> Result<Describe, Error> {
|
fn try_from(bytes: &BytesMut) -> Result<Describe, Error> {
|
||||||
let mut cursor = Cursor::new(bytes);
|
let mut cursor = Cursor::new(bytes);
|
||||||
let code = cursor.get_u8() as char;
|
let code = cursor.get_u8() as char;
|
||||||
let len = cursor.get_u32();
|
let len = cursor.get_i32();
|
||||||
let target = cursor.get_u8() as char;
|
let target = cursor.get_u8() as char;
|
||||||
let statement_name = cursor.read_string()?;
|
let statement_name = cursor.read_string()?;
|
||||||
|
|
||||||
@@ -1145,7 +1146,7 @@ impl TryFrom<Describe> for BytesMut {
|
|||||||
let len = 4 + 1 + statement_name.len();
|
let len = 4 + 1 + statement_name.len();
|
||||||
|
|
||||||
bytes.put_u8(describe.code as u8);
|
bytes.put_u8(describe.code as u8);
|
||||||
bytes.put_u32(len as u32);
|
bytes.put_i32(len as i32);
|
||||||
bytes.put_u8(describe.target as u8);
|
bytes.put_u8(describe.target as u8);
|
||||||
bytes.put_slice(statement_name);
|
bytes.put_slice(statement_name);
|
||||||
|
|
||||||
|
|||||||
@@ -698,6 +698,7 @@ impl Server {
|
|||||||
))
|
))
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
trace!("Error: {}", error_code);
|
trace!("Error: {}", error_code);
|
||||||
|
|
||||||
match error_code {
|
match error_code {
|
||||||
@@ -1012,12 +1013,6 @@ impl Server {
|
|||||||
// which can leak between clients. This is a best effort to block bad clients
|
// which can leak between clients. This is a best effort to block bad clients
|
||||||
// from poisoning a transaction-mode pool by setting inappropriate session variables
|
// from poisoning a transaction-mode pool by setting inappropriate session variables
|
||||||
match command.as_str() {
|
match command.as_str() {
|
||||||
"DISCARD ALL" => {
|
|
||||||
self.clear_prepared_statement_cache();
|
|
||||||
}
|
|
||||||
"DEALLOCATE ALL" => {
|
|
||||||
self.clear_prepared_statement_cache();
|
|
||||||
}
|
|
||||||
"SET" => {
|
"SET" => {
|
||||||
// We don't detect set statements in transactions
|
// We don't detect set statements in transactions
|
||||||
// No great way to differentiate between set and set local
|
// No great way to differentiate between set and set local
|
||||||
@@ -1137,12 +1132,6 @@ impl Server {
|
|||||||
has_it
|
has_it
|
||||||
}
|
}
|
||||||
|
|
||||||
fn clear_prepared_statement_cache(&mut self) {
|
|
||||||
if let Some(cache) = &mut self.prepared_statement_cache {
|
|
||||||
cache.clear();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn add_prepared_statement_to_cache(&mut self, name: &str) -> Option<String> {
|
fn add_prepared_statement_to_cache(&mut self, name: &str) -> Option<String> {
|
||||||
let cache = match &mut self.prepared_statement_cache {
|
let cache = match &mut self.prepared_statement_cache {
|
||||||
Some(cache) => cache,
|
Some(cache) => cache,
|
||||||
|
|||||||
@@ -23,11 +23,11 @@ docker compose exec --workdir /app/tests/python main pip3 install -r requirement
|
|||||||
echo "Interactive test environment ready"
|
echo "Interactive test environment ready"
|
||||||
echo "To run integration tests, you can use the following commands:"
|
echo "To run integration tests, you can use the following commands:"
|
||||||
echo -e " ${BLUE}Ruby: ${RED}cd /app/tests/ruby && bundle exec ruby tests.rb --format documentation${RESET}"
|
echo -e " ${BLUE}Ruby: ${RED}cd /app/tests/ruby && bundle exec ruby tests.rb --format documentation${RESET}"
|
||||||
echo -e " ${BLUE}Python: ${RED}cd /app && python3 tests/python/tests.py${RESET}"
|
echo -e " ${BLUE}Python: ${RED}cd /app/ && pytest ${RESET}"
|
||||||
echo -e " ${BLUE}Rust: ${RED}cd /app/tests/rust && cargo run ${RESET}"
|
echo -e " ${BLUE}Rust: ${RED}cd /app/tests/rust && cargo run ${RESET}"
|
||||||
echo -e " ${BLUE}Go: ${RED}cd /app/tests/go && /usr/local/go/bin/go test${RESET}"
|
echo -e " ${BLUE}Go: ${RED}cd /app/tests/go && /usr/local/go/bin/go test${RESET}"
|
||||||
echo "the source code for tests are directly linked to the source code in the container so you can modify the code and run the tests again"
|
echo "the source code for tests are directly linked to the source code in the container so you can modify the code and run the tests again"
|
||||||
echo "You can rebuild PgCat from within the container by running"
|
echo "You can rebuild PgCat from within the container by running"
|
||||||
echo -e " ${GREEN}cargo build${RESET}"
|
echo -e " ${GREEN}cargo build${RESET}"
|
||||||
echo "and then run the tests again"
|
echo "and then run the tests again"
|
||||||
echo "==================================="
|
echo "==================================="
|
||||||
|
|||||||
0
tests/python/conftest.py
Normal file
0
tests/python/conftest.py
Normal file
@@ -1,2 +1,3 @@
|
|||||||
|
pytest
|
||||||
psycopg2==2.9.3
|
psycopg2==2.9.3
|
||||||
psutil==5.9.1
|
psutil==5.9.1
|
||||||
|
|||||||
@@ -1,83 +1,29 @@
|
|||||||
from typing import Tuple
|
|
||||||
import psycopg2
|
|
||||||
import psutil
|
|
||||||
import os
|
import os
|
||||||
import signal
|
import signal
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
import psycopg2
|
||||||
|
|
||||||
|
import utils
|
||||||
|
|
||||||
SHUTDOWN_TIMEOUT = 5
|
SHUTDOWN_TIMEOUT = 5
|
||||||
|
|
||||||
PGCAT_HOST = "127.0.0.1"
|
|
||||||
PGCAT_PORT = "6432"
|
|
||||||
|
|
||||||
|
|
||||||
def pgcat_start():
|
|
||||||
pg_cat_send_signal(signal.SIGTERM)
|
|
||||||
os.system("./target/debug/pgcat .circleci/pgcat.toml &")
|
|
||||||
time.sleep(2)
|
|
||||||
|
|
||||||
|
|
||||||
def pg_cat_send_signal(signal: signal.Signals):
|
|
||||||
try:
|
|
||||||
for proc in psutil.process_iter(["pid", "name"]):
|
|
||||||
if "pgcat" == proc.name():
|
|
||||||
os.kill(proc.pid, signal)
|
|
||||||
except Exception as e:
|
|
||||||
# The process can be gone when we send this signal
|
|
||||||
print(e)
|
|
||||||
|
|
||||||
if signal == signal.SIGTERM:
|
|
||||||
# Returns 0 if pgcat process exists
|
|
||||||
time.sleep(2)
|
|
||||||
if not os.system('pgrep pgcat'):
|
|
||||||
raise Exception("pgcat not closed after SIGTERM")
|
|
||||||
|
|
||||||
|
|
||||||
def connect_db(
|
|
||||||
autocommit: bool = True,
|
|
||||||
admin: bool = False,
|
|
||||||
) -> Tuple[psycopg2.extensions.connection, psycopg2.extensions.cursor]:
|
|
||||||
|
|
||||||
if admin:
|
|
||||||
user = "admin_user"
|
|
||||||
password = "admin_pass"
|
|
||||||
db = "pgcat"
|
|
||||||
else:
|
|
||||||
user = "sharding_user"
|
|
||||||
password = "sharding_user"
|
|
||||||
db = "sharded_db"
|
|
||||||
|
|
||||||
conn = psycopg2.connect(
|
|
||||||
f"postgres://{user}:{password}@{PGCAT_HOST}:{PGCAT_PORT}/{db}?application_name=testing_pgcat",
|
|
||||||
connect_timeout=2,
|
|
||||||
)
|
|
||||||
conn.autocommit = autocommit
|
|
||||||
cur = conn.cursor()
|
|
||||||
|
|
||||||
return (conn, cur)
|
|
||||||
|
|
||||||
|
|
||||||
def cleanup_conn(conn: psycopg2.extensions.connection, cur: psycopg2.extensions.cursor):
|
|
||||||
cur.close()
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
|
|
||||||
def test_normal_db_access():
|
def test_normal_db_access():
|
||||||
pgcat_start()
|
utils.pgcat_start()
|
||||||
conn, cur = connect_db(autocommit=False)
|
conn, cur = utils.connect_db(autocommit=False)
|
||||||
cur.execute("SELECT 1")
|
cur.execute("SELECT 1")
|
||||||
res = cur.fetchall()
|
res = cur.fetchall()
|
||||||
print(res)
|
print(res)
|
||||||
cleanup_conn(conn, cur)
|
utils.cleanup_conn(conn, cur)
|
||||||
|
|
||||||
|
|
||||||
def test_admin_db_access():
|
def test_admin_db_access():
|
||||||
conn, cur = connect_db(admin=True)
|
conn, cur = utils.connect_db(admin=True)
|
||||||
|
|
||||||
cur.execute("SHOW POOLS")
|
cur.execute("SHOW POOLS")
|
||||||
res = cur.fetchall()
|
res = cur.fetchall()
|
||||||
print(res)
|
print(res)
|
||||||
cleanup_conn(conn, cur)
|
utils.cleanup_conn(conn, cur)
|
||||||
|
|
||||||
|
|
||||||
def test_shutdown_logic():
|
def test_shutdown_logic():
|
||||||
@@ -86,17 +32,17 @@ def test_shutdown_logic():
|
|||||||
# NO ACTIVE QUERIES SIGINT HANDLING
|
# NO ACTIVE QUERIES SIGINT HANDLING
|
||||||
|
|
||||||
# Start pgcat
|
# Start pgcat
|
||||||
pgcat_start()
|
utils.pgcat_start()
|
||||||
|
|
||||||
# Create client connection and send query (not in transaction)
|
# Create client connection and send query (not in transaction)
|
||||||
conn, cur = connect_db()
|
conn, cur = utils.connect_db()
|
||||||
|
|
||||||
cur.execute("BEGIN;")
|
cur.execute("BEGIN;")
|
||||||
cur.execute("SELECT 1;")
|
cur.execute("SELECT 1;")
|
||||||
cur.execute("COMMIT;")
|
cur.execute("COMMIT;")
|
||||||
|
|
||||||
# Send sigint to pgcat
|
# Send sigint to pgcat
|
||||||
pg_cat_send_signal(signal.SIGINT)
|
utils.pg_cat_send_signal(signal.SIGINT)
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
|
|
||||||
# Check that any new queries fail after sigint since server should close with no active transactions
|
# Check that any new queries fail after sigint since server should close with no active transactions
|
||||||
@@ -108,18 +54,18 @@ def test_shutdown_logic():
|
|||||||
# Fail if query execution succeeded
|
# Fail if query execution succeeded
|
||||||
raise Exception("Server not closed after sigint")
|
raise Exception("Server not closed after sigint")
|
||||||
|
|
||||||
cleanup_conn(conn, cur)
|
utils.cleanup_conn(conn, cur)
|
||||||
pg_cat_send_signal(signal.SIGTERM)
|
utils.pg_cat_send_signal(signal.SIGTERM)
|
||||||
|
|
||||||
# - - - - - - - - - - - - - - - - - -
|
# - - - - - - - - - - - - - - - - - -
|
||||||
# NO ACTIVE QUERIES ADMIN SHUTDOWN COMMAND
|
# NO ACTIVE QUERIES ADMIN SHUTDOWN COMMAND
|
||||||
|
|
||||||
# Start pgcat
|
# Start pgcat
|
||||||
pgcat_start()
|
utils.pgcat_start()
|
||||||
|
|
||||||
# Create client connection and begin transaction
|
# Create client connection and begin transaction
|
||||||
conn, cur = connect_db()
|
conn, cur = utils.connect_db()
|
||||||
admin_conn, admin_cur = connect_db(admin=True)
|
admin_conn, admin_cur = utils.connect_db(admin=True)
|
||||||
|
|
||||||
cur.execute("BEGIN;")
|
cur.execute("BEGIN;")
|
||||||
cur.execute("SELECT 1;")
|
cur.execute("SELECT 1;")
|
||||||
@@ -138,24 +84,24 @@ def test_shutdown_logic():
|
|||||||
# Fail if query execution succeeded
|
# Fail if query execution succeeded
|
||||||
raise Exception("Server not closed after sigint")
|
raise Exception("Server not closed after sigint")
|
||||||
|
|
||||||
cleanup_conn(conn, cur)
|
utils.cleanup_conn(conn, cur)
|
||||||
cleanup_conn(admin_conn, admin_cur)
|
utils.cleanup_conn(admin_conn, admin_cur)
|
||||||
pg_cat_send_signal(signal.SIGTERM)
|
utils.pg_cat_send_signal(signal.SIGTERM)
|
||||||
|
|
||||||
# - - - - - - - - - - - - - - - - - -
|
# - - - - - - - - - - - - - - - - - -
|
||||||
# HANDLE TRANSACTION WITH SIGINT
|
# HANDLE TRANSACTION WITH SIGINT
|
||||||
|
|
||||||
# Start pgcat
|
# Start pgcat
|
||||||
pgcat_start()
|
utils.pgcat_start()
|
||||||
|
|
||||||
# Create client connection and begin transaction
|
# Create client connection and begin transaction
|
||||||
conn, cur = connect_db()
|
conn, cur = utils.connect_db()
|
||||||
|
|
||||||
cur.execute("BEGIN;")
|
cur.execute("BEGIN;")
|
||||||
cur.execute("SELECT 1;")
|
cur.execute("SELECT 1;")
|
||||||
|
|
||||||
# Send sigint to pgcat while still in transaction
|
# Send sigint to pgcat while still in transaction
|
||||||
pg_cat_send_signal(signal.SIGINT)
|
utils.pg_cat_send_signal(signal.SIGINT)
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
|
|
||||||
# Check that any new queries succeed after sigint since server should still allow transaction to complete
|
# Check that any new queries succeed after sigint since server should still allow transaction to complete
|
||||||
@@ -165,18 +111,18 @@ def test_shutdown_logic():
|
|||||||
# Fail if query fails since server closed
|
# Fail if query fails since server closed
|
||||||
raise Exception("Server closed while in transaction", e.pgerror)
|
raise Exception("Server closed while in transaction", e.pgerror)
|
||||||
|
|
||||||
cleanup_conn(conn, cur)
|
utils.cleanup_conn(conn, cur)
|
||||||
pg_cat_send_signal(signal.SIGTERM)
|
utils.pg_cat_send_signal(signal.SIGTERM)
|
||||||
|
|
||||||
# - - - - - - - - - - - - - - - - - -
|
# - - - - - - - - - - - - - - - - - -
|
||||||
# HANDLE TRANSACTION WITH ADMIN SHUTDOWN COMMAND
|
# HANDLE TRANSACTION WITH ADMIN SHUTDOWN COMMAND
|
||||||
|
|
||||||
# Start pgcat
|
# Start pgcat
|
||||||
pgcat_start()
|
utils.pgcat_start()
|
||||||
|
|
||||||
# Create client connection and begin transaction
|
# Create client connection and begin transaction
|
||||||
conn, cur = connect_db()
|
conn, cur = utils.connect_db()
|
||||||
admin_conn, admin_cur = connect_db(admin=True)
|
admin_conn, admin_cur = utils.connect_db(admin=True)
|
||||||
|
|
||||||
cur.execute("BEGIN;")
|
cur.execute("BEGIN;")
|
||||||
cur.execute("SELECT 1;")
|
cur.execute("SELECT 1;")
|
||||||
@@ -194,30 +140,30 @@ def test_shutdown_logic():
|
|||||||
# Fail if query fails since server closed
|
# Fail if query fails since server closed
|
||||||
raise Exception("Server closed while in transaction", e.pgerror)
|
raise Exception("Server closed while in transaction", e.pgerror)
|
||||||
|
|
||||||
cleanup_conn(conn, cur)
|
utils.cleanup_conn(conn, cur)
|
||||||
cleanup_conn(admin_conn, admin_cur)
|
utils.cleanup_conn(admin_conn, admin_cur)
|
||||||
pg_cat_send_signal(signal.SIGTERM)
|
utils.pg_cat_send_signal(signal.SIGTERM)
|
||||||
|
|
||||||
# - - - - - - - - - - - - - - - - - -
|
# - - - - - - - - - - - - - - - - - -
|
||||||
# NO NEW NON-ADMIN CONNECTIONS DURING SHUTDOWN
|
# NO NEW NON-ADMIN CONNECTIONS DURING SHUTDOWN
|
||||||
# Start pgcat
|
# Start pgcat
|
||||||
pgcat_start()
|
utils.pgcat_start()
|
||||||
|
|
||||||
# Create client connection and begin transaction
|
# Create client connection and begin transaction
|
||||||
transaction_conn, transaction_cur = connect_db()
|
transaction_conn, transaction_cur = utils.connect_db()
|
||||||
|
|
||||||
transaction_cur.execute("BEGIN;")
|
transaction_cur.execute("BEGIN;")
|
||||||
transaction_cur.execute("SELECT 1;")
|
transaction_cur.execute("SELECT 1;")
|
||||||
|
|
||||||
# Send sigint to pgcat while still in transaction
|
# Send sigint to pgcat while still in transaction
|
||||||
pg_cat_send_signal(signal.SIGINT)
|
utils.pg_cat_send_signal(signal.SIGINT)
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
|
|
||||||
start = time.perf_counter()
|
start = time.perf_counter()
|
||||||
try:
|
try:
|
||||||
conn, cur = connect_db()
|
conn, cur = utils.connect_db()
|
||||||
cur.execute("SELECT 1;")
|
cur.execute("SELECT 1;")
|
||||||
cleanup_conn(conn, cur)
|
utils.cleanup_conn(conn, cur)
|
||||||
except psycopg2.OperationalError as e:
|
except psycopg2.OperationalError as e:
|
||||||
time_taken = time.perf_counter() - start
|
time_taken = time.perf_counter() - start
|
||||||
if time_taken > 0.1:
|
if time_taken > 0.1:
|
||||||
@@ -227,49 +173,49 @@ def test_shutdown_logic():
|
|||||||
else:
|
else:
|
||||||
raise Exception("Able connect to database during shutdown")
|
raise Exception("Able connect to database during shutdown")
|
||||||
|
|
||||||
cleanup_conn(transaction_conn, transaction_cur)
|
utils.cleanup_conn(transaction_conn, transaction_cur)
|
||||||
pg_cat_send_signal(signal.SIGTERM)
|
utils.pg_cat_send_signal(signal.SIGTERM)
|
||||||
|
|
||||||
# - - - - - - - - - - - - - - - - - -
|
# - - - - - - - - - - - - - - - - - -
|
||||||
# ALLOW NEW ADMIN CONNECTIONS DURING SHUTDOWN
|
# ALLOW NEW ADMIN CONNECTIONS DURING SHUTDOWN
|
||||||
# Start pgcat
|
# Start pgcat
|
||||||
pgcat_start()
|
utils.pgcat_start()
|
||||||
|
|
||||||
# Create client connection and begin transaction
|
# Create client connection and begin transaction
|
||||||
transaction_conn, transaction_cur = connect_db()
|
transaction_conn, transaction_cur = utils.connect_db()
|
||||||
|
|
||||||
transaction_cur.execute("BEGIN;")
|
transaction_cur.execute("BEGIN;")
|
||||||
transaction_cur.execute("SELECT 1;")
|
transaction_cur.execute("SELECT 1;")
|
||||||
|
|
||||||
# Send sigint to pgcat while still in transaction
|
# Send sigint to pgcat while still in transaction
|
||||||
pg_cat_send_signal(signal.SIGINT)
|
utils.pg_cat_send_signal(signal.SIGINT)
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
conn, cur = connect_db(admin=True)
|
conn, cur = utils.connect_db(admin=True)
|
||||||
cur.execute("SHOW DATABASES;")
|
cur.execute("SHOW DATABASES;")
|
||||||
cleanup_conn(conn, cur)
|
utils.cleanup_conn(conn, cur)
|
||||||
except psycopg2.OperationalError as e:
|
except psycopg2.OperationalError as e:
|
||||||
raise Exception(e)
|
raise Exception(e)
|
||||||
|
|
||||||
cleanup_conn(transaction_conn, transaction_cur)
|
utils.cleanup_conn(transaction_conn, transaction_cur)
|
||||||
pg_cat_send_signal(signal.SIGTERM)
|
utils.pg_cat_send_signal(signal.SIGTERM)
|
||||||
|
|
||||||
# - - - - - - - - - - - - - - - - - -
|
# - - - - - - - - - - - - - - - - - -
|
||||||
# ADMIN CONNECTIONS CONTINUING TO WORK AFTER SHUTDOWN
|
# ADMIN CONNECTIONS CONTINUING TO WORK AFTER SHUTDOWN
|
||||||
# Start pgcat
|
# Start pgcat
|
||||||
pgcat_start()
|
utils.pgcat_start()
|
||||||
|
|
||||||
# Create client connection and begin transaction
|
# Create client connection and begin transaction
|
||||||
transaction_conn, transaction_cur = connect_db()
|
transaction_conn, transaction_cur = utils.connect_db()
|
||||||
transaction_cur.execute("BEGIN;")
|
transaction_cur.execute("BEGIN;")
|
||||||
transaction_cur.execute("SELECT 1;")
|
transaction_cur.execute("SELECT 1;")
|
||||||
|
|
||||||
admin_conn, admin_cur = connect_db(admin=True)
|
admin_conn, admin_cur = utils.connect_db(admin=True)
|
||||||
admin_cur.execute("SHOW DATABASES;")
|
admin_cur.execute("SHOW DATABASES;")
|
||||||
|
|
||||||
# Send sigint to pgcat while still in transaction
|
# Send sigint to pgcat while still in transaction
|
||||||
pg_cat_send_signal(signal.SIGINT)
|
utils.pg_cat_send_signal(signal.SIGINT)
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -277,24 +223,24 @@ def test_shutdown_logic():
|
|||||||
except psycopg2.OperationalError as e:
|
except psycopg2.OperationalError as e:
|
||||||
raise Exception("Could not execute admin command:", e)
|
raise Exception("Could not execute admin command:", e)
|
||||||
|
|
||||||
cleanup_conn(transaction_conn, transaction_cur)
|
utils.cleanup_conn(transaction_conn, transaction_cur)
|
||||||
cleanup_conn(admin_conn, admin_cur)
|
utils.cleanup_conn(admin_conn, admin_cur)
|
||||||
pg_cat_send_signal(signal.SIGTERM)
|
utils.pg_cat_send_signal(signal.SIGTERM)
|
||||||
|
|
||||||
# - - - - - - - - - - - - - - - - - -
|
# - - - - - - - - - - - - - - - - - -
|
||||||
# HANDLE SHUTDOWN TIMEOUT WITH SIGINT
|
# HANDLE SHUTDOWN TIMEOUT WITH SIGINT
|
||||||
|
|
||||||
# Start pgcat
|
# Start pgcat
|
||||||
pgcat_start()
|
utils.pgcat_start()
|
||||||
|
|
||||||
# Create client connection and begin transaction, which should prevent server shutdown unless shutdown timeout is reached
|
# Create client connection and begin transaction, which should prevent server shutdown unless shutdown timeout is reached
|
||||||
conn, cur = connect_db()
|
conn, cur = utils.connect_db()
|
||||||
|
|
||||||
cur.execute("BEGIN;")
|
cur.execute("BEGIN;")
|
||||||
cur.execute("SELECT 1;")
|
cur.execute("SELECT 1;")
|
||||||
|
|
||||||
# Send sigint to pgcat while still in transaction
|
# Send sigint to pgcat while still in transaction
|
||||||
pg_cat_send_signal(signal.SIGINT)
|
utils.pg_cat_send_signal(signal.SIGINT)
|
||||||
|
|
||||||
# pgcat shutdown timeout is set to SHUTDOWN_TIMEOUT seconds, so we sleep for SHUTDOWN_TIMEOUT + 1 seconds
|
# pgcat shutdown timeout is set to SHUTDOWN_TIMEOUT seconds, so we sleep for SHUTDOWN_TIMEOUT + 1 seconds
|
||||||
time.sleep(SHUTDOWN_TIMEOUT + 1)
|
time.sleep(SHUTDOWN_TIMEOUT + 1)
|
||||||
@@ -308,12 +254,5 @@ def test_shutdown_logic():
|
|||||||
# Fail if query execution succeeded
|
# Fail if query execution succeeded
|
||||||
raise Exception("Server not closed after sigint and expected timeout")
|
raise Exception("Server not closed after sigint and expected timeout")
|
||||||
|
|
||||||
cleanup_conn(conn, cur)
|
utils.cleanup_conn(conn, cur)
|
||||||
pg_cat_send_signal(signal.SIGTERM)
|
utils.pg_cat_send_signal(signal.SIGTERM)
|
||||||
|
|
||||||
# - - - - - - - - - - - - - - - - - -
|
|
||||||
|
|
||||||
|
|
||||||
test_normal_db_access()
|
|
||||||
test_admin_db_access()
|
|
||||||
test_shutdown_logic()
|
|
||||||
60
tests/python/utils.py
Normal file
60
tests/python/utils.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
from typing import Tuple
|
||||||
|
import os
|
||||||
|
import psutil
|
||||||
|
import signal
|
||||||
|
import time
|
||||||
|
|
||||||
|
import psycopg2
|
||||||
|
|
||||||
|
PGCAT_HOST = "127.0.0.1"
|
||||||
|
PGCAT_PORT = "6432"
|
||||||
|
|
||||||
|
def pgcat_start():
|
||||||
|
pg_cat_send_signal(signal.SIGTERM)
|
||||||
|
os.system("./target/debug/pgcat .circleci/pgcat.toml &")
|
||||||
|
time.sleep(2)
|
||||||
|
|
||||||
|
|
||||||
|
def pg_cat_send_signal(signal: signal.Signals):
|
||||||
|
try:
|
||||||
|
for proc in psutil.process_iter(["pid", "name"]):
|
||||||
|
if "pgcat" == proc.name():
|
||||||
|
os.kill(proc.pid, signal)
|
||||||
|
except Exception as e:
|
||||||
|
# The process can be gone when we send this signal
|
||||||
|
print(e)
|
||||||
|
|
||||||
|
if signal == signal.SIGTERM:
|
||||||
|
# Returns 0 if pgcat process exists
|
||||||
|
time.sleep(2)
|
||||||
|
if not os.system('pgrep pgcat'):
|
||||||
|
raise Exception("pgcat not closed after SIGTERM")
|
||||||
|
|
||||||
|
|
||||||
|
def connect_db(
|
||||||
|
autocommit: bool = True,
|
||||||
|
admin: bool = False,
|
||||||
|
) -> Tuple[psycopg2.extensions.connection, psycopg2.extensions.cursor]:
|
||||||
|
|
||||||
|
if admin:
|
||||||
|
user = "admin_user"
|
||||||
|
password = "admin_pass"
|
||||||
|
db = "pgcat"
|
||||||
|
else:
|
||||||
|
user = "sharding_user"
|
||||||
|
password = "sharding_user"
|
||||||
|
db = "sharded_db"
|
||||||
|
|
||||||
|
conn = psycopg2.connect(
|
||||||
|
f"postgres://{user}:{password}@{PGCAT_HOST}:{PGCAT_PORT}/{db}?application_name=testing_pgcat",
|
||||||
|
connect_timeout=2,
|
||||||
|
)
|
||||||
|
conn.autocommit = autocommit
|
||||||
|
cur = conn.cursor()
|
||||||
|
|
||||||
|
return (conn, cur)
|
||||||
|
|
||||||
|
|
||||||
|
def cleanup_conn(conn: psycopg2.extensions.connection, cur: psycopg2.extensions.cursor):
|
||||||
|
cur.close()
|
||||||
|
conn.close()
|
||||||
@@ -1,145 +0,0 @@
|
|||||||
|
|
||||||
class PostgresMessage
|
|
||||||
# Base class for common functionality
|
|
||||||
|
|
||||||
def encode_string(str)
|
|
||||||
"#{str}\0" # Encode a string with a null terminator
|
|
||||||
end
|
|
||||||
|
|
||||||
def encode_int16(value)
|
|
||||||
[value].pack('n') # Encode an Int16
|
|
||||||
end
|
|
||||||
|
|
||||||
def encode_int32(value)
|
|
||||||
[value].pack('N') # Encode an Int32
|
|
||||||
end
|
|
||||||
|
|
||||||
def message_prefix(type, length)
|
|
||||||
"#{type}#{encode_int32(length)}" # Message type and length prefix
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
class SimpleQueryMessage < PostgresMessage
|
|
||||||
attr_accessor :query
|
|
||||||
|
|
||||||
def initialize(query = "")
|
|
||||||
@query = query
|
|
||||||
end
|
|
||||||
|
|
||||||
def to_bytes
|
|
||||||
query_bytes = encode_string(@query)
|
|
||||||
length = 4 + query_bytes.size # Length includes 4 bytes for length itself
|
|
||||||
message_prefix('Q', length) + query_bytes
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
class ParseMessage < PostgresMessage
|
|
||||||
attr_accessor :statement_name, :query, :parameter_types
|
|
||||||
|
|
||||||
def initialize(statement_name = "", query = "", parameter_types = [])
|
|
||||||
@statement_name = statement_name
|
|
||||||
@query = query
|
|
||||||
@parameter_types = parameter_types
|
|
||||||
end
|
|
||||||
|
|
||||||
def to_bytes
|
|
||||||
statement_name_bytes = encode_string(@statement_name)
|
|
||||||
query_bytes = encode_string(@query)
|
|
||||||
parameter_types_bytes = @parameter_types.pack('N*')
|
|
||||||
|
|
||||||
length = 4 + statement_name_bytes.size + query_bytes.size + 2 + parameter_types_bytes.size
|
|
||||||
message_prefix('P', length) + statement_name_bytes + query_bytes + encode_int16(@parameter_types.size) + parameter_types_bytes
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
class BindMessage < PostgresMessage
|
|
||||||
attr_accessor :portal_name, :statement_name, :parameter_format_codes, :parameters, :result_column_format_codes
|
|
||||||
|
|
||||||
def initialize(portal_name = "", statement_name = "", parameter_format_codes = [], parameters = [], result_column_format_codes = [])
|
|
||||||
@portal_name = portal_name
|
|
||||||
@statement_name = statement_name
|
|
||||||
@parameter_format_codes = parameter_format_codes
|
|
||||||
@parameters = parameters
|
|
||||||
@result_column_format_codes = result_column_format_codes
|
|
||||||
end
|
|
||||||
|
|
||||||
def to_bytes
|
|
||||||
portal_name_bytes = encode_string(@portal_name)
|
|
||||||
statement_name_bytes = encode_string(@statement_name)
|
|
||||||
parameter_format_codes_bytes = @parameter_format_codes.pack('n*')
|
|
||||||
|
|
||||||
parameters_bytes = @parameters.map do |param|
|
|
||||||
if param.nil?
|
|
||||||
encode_int32(-1)
|
|
||||||
else
|
|
||||||
encode_int32(param.bytesize) + param
|
|
||||||
end
|
|
||||||
end.join
|
|
||||||
|
|
||||||
result_column_format_codes_bytes = @result_column_format_codes.pack('n*')
|
|
||||||
|
|
||||||
length = 4 + portal_name_bytes.size + statement_name_bytes.size + 2 + parameter_format_codes_bytes.size + 2 + parameters_bytes.size + 2 + result_column_format_codes_bytes.size
|
|
||||||
message_prefix('B', length) + portal_name_bytes + statement_name_bytes + encode_int16(@parameter_format_codes.size) + parameter_format_codes_bytes + encode_int16(@parameters.size) + parameters_bytes + encode_int16(@result_column_format_codes.size) + result_column_format_codes_bytes
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
class DescribeMessage < PostgresMessage
|
|
||||||
attr_accessor :type, :name
|
|
||||||
|
|
||||||
def initialize(type = 'S', name = "")
|
|
||||||
@type = type
|
|
||||||
@name = name
|
|
||||||
end
|
|
||||||
|
|
||||||
def to_bytes
|
|
||||||
name_bytes = encode_string(@name)
|
|
||||||
length = 4 + 1 + name_bytes.size
|
|
||||||
message_prefix('D', length) + @type + name_bytes
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
|
|
||||||
class ExecuteMessage < PostgresMessage
|
|
||||||
attr_accessor :portal_name, :max_rows
|
|
||||||
|
|
||||||
def initialize(portal_name = "", max_rows = 0)
|
|
||||||
@portal_name = portal_name
|
|
||||||
@max_rows = max_rows
|
|
||||||
end
|
|
||||||
|
|
||||||
def to_bytes
|
|
||||||
portal_name_bytes = encode_string(@portal_name)
|
|
||||||
length = 4 + portal_name_bytes.size + 4
|
|
||||||
message_prefix('E', length) + portal_name_bytes + encode_int32(@max_rows)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
class FlushMessage < PostgresMessage
|
|
||||||
def to_bytes
|
|
||||||
length = 4
|
|
||||||
message_prefix('H', length)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
class SyncMessage < PostgresMessage
|
|
||||||
def to_bytes
|
|
||||||
length = 4
|
|
||||||
message_prefix('S', length)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
class CloseMessage < PostgresMessage
|
|
||||||
attr_accessor :type, :name
|
|
||||||
|
|
||||||
def initialize(type = 'S', name = "")
|
|
||||||
@type = type
|
|
||||||
@name = name
|
|
||||||
end
|
|
||||||
|
|
||||||
def to_bytes
|
|
||||||
name_bytes = encode_string(@name)
|
|
||||||
length = 4 + 1 + name_bytes.size
|
|
||||||
message_prefix('C', length) + @type + name_bytes
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
@@ -1,6 +1,5 @@
|
|||||||
require 'socket'
|
require 'socket'
|
||||||
require 'digest/md5'
|
require 'digest/md5'
|
||||||
require_relative 'frontend_messages'
|
|
||||||
|
|
||||||
BACKEND_MESSAGE_CODES = {
|
BACKEND_MESSAGE_CODES = {
|
||||||
'Z' => "ReadyForQuery",
|
'Z' => "ReadyForQuery",
|
||||||
@@ -19,13 +18,9 @@ class PostgresSocket
|
|||||||
@host = host
|
@host = host
|
||||||
@socket = TCPSocket.new @host, @port
|
@socket = TCPSocket.new @host, @port
|
||||||
@parameters = {}
|
@parameters = {}
|
||||||
@verbose = false
|
@verbose = true
|
||||||
end
|
end
|
||||||
|
|
||||||
def send_message(message)
|
|
||||||
@socket.write(message.to_bytes)
|
|
||||||
end
|
|
||||||
|
|
||||||
def send_md5_password_message(username, password, salt)
|
def send_md5_password_message(username, password, salt)
|
||||||
m = Digest::MD5.hexdigest(password + username)
|
m = Digest::MD5.hexdigest(password + username)
|
||||||
m = Digest::MD5.hexdigest(m + salt.map(&:chr).join(""))
|
m = Digest::MD5.hexdigest(m + salt.map(&:chr).join(""))
|
||||||
@@ -118,6 +113,107 @@ class PostgresSocket
|
|||||||
log "[F] Sent CancelRequest message"
|
log "[F] Sent CancelRequest message"
|
||||||
end
|
end
|
||||||
|
|
||||||
|
def send_query_message(query)
|
||||||
|
query_size = query.length
|
||||||
|
message_size = 1 + 4 + query_size
|
||||||
|
message = []
|
||||||
|
message << "Q".ord
|
||||||
|
message << [message_size].pack('l>').unpack('CCCC') # 4
|
||||||
|
message << query.split('').map(&:ord) # 2, 11
|
||||||
|
message << 0 # 1, 12
|
||||||
|
message.flatten!
|
||||||
|
@socket.write(message.flatten.pack('C*'))
|
||||||
|
log "[F] Sent Q message (#{query})"
|
||||||
|
end
|
||||||
|
|
||||||
|
def send_parse_message(query)
|
||||||
|
query_size = query.length
|
||||||
|
message_size = 2 + 2 + 4 + query_size
|
||||||
|
message = []
|
||||||
|
message << "P".ord
|
||||||
|
message << [message_size].pack('l>').unpack('CCCC') # 4
|
||||||
|
message << 0 # unnamed statement
|
||||||
|
message << query.split('').map(&:ord) # 2, 11
|
||||||
|
message << 0 # 1, 12
|
||||||
|
message << [0, 0]
|
||||||
|
message.flatten!
|
||||||
|
@socket.write(message.flatten.pack('C*'))
|
||||||
|
log "[F] Sent P message (#{query})"
|
||||||
|
end
|
||||||
|
|
||||||
|
def send_bind_message
|
||||||
|
message = []
|
||||||
|
message << "B".ord
|
||||||
|
message << [12].pack('l>').unpack('CCCC') # 4
|
||||||
|
message << 0 # unnamed statement
|
||||||
|
message << 0 # unnamed statement
|
||||||
|
message << [0, 0] # 2
|
||||||
|
message << [0, 0] # 2
|
||||||
|
message << [0, 0] # 2
|
||||||
|
message.flatten!
|
||||||
|
@socket.write(message.flatten.pack('C*'))
|
||||||
|
log "[F] Sent B message"
|
||||||
|
end
|
||||||
|
|
||||||
|
def send_describe_message(mode)
|
||||||
|
message = []
|
||||||
|
message << "D".ord
|
||||||
|
message << [6].pack('l>').unpack('CCCC') # 4
|
||||||
|
message << mode.ord
|
||||||
|
message << 0 # unnamed statement
|
||||||
|
message.flatten!
|
||||||
|
@socket.write(message.flatten.pack('C*'))
|
||||||
|
log "[F] Sent D message"
|
||||||
|
end
|
||||||
|
|
||||||
|
def send_execute_message(limit=0)
|
||||||
|
message = []
|
||||||
|
message << "E".ord
|
||||||
|
message << [9].pack('l>').unpack('CCCC') # 4
|
||||||
|
message << 0 # unnamed statement
|
||||||
|
message << [limit].pack('l>').unpack('CCCC') # 4
|
||||||
|
message.flatten!
|
||||||
|
@socket.write(message.flatten.pack('C*'))
|
||||||
|
log "[F] Sent E message"
|
||||||
|
end
|
||||||
|
|
||||||
|
def send_sync_message
|
||||||
|
message = []
|
||||||
|
message << "S".ord
|
||||||
|
message << [4].pack('l>').unpack('CCCC') # 4
|
||||||
|
message.flatten!
|
||||||
|
@socket.write(message.flatten.pack('C*'))
|
||||||
|
log "[F] Sent S message"
|
||||||
|
end
|
||||||
|
|
||||||
|
def send_copydone_message
|
||||||
|
message = []
|
||||||
|
message << "c".ord
|
||||||
|
message << [4].pack('l>').unpack('CCCC') # 4
|
||||||
|
message.flatten!
|
||||||
|
@socket.write(message.flatten.pack('C*'))
|
||||||
|
log "[F] Sent c message"
|
||||||
|
end
|
||||||
|
|
||||||
|
def send_copyfail_message
|
||||||
|
message = []
|
||||||
|
message << "f".ord
|
||||||
|
message << [5].pack('l>').unpack('CCCC') # 4
|
||||||
|
message << 0
|
||||||
|
message.flatten!
|
||||||
|
@socket.write(message.flatten.pack('C*'))
|
||||||
|
log "[F] Sent f message"
|
||||||
|
end
|
||||||
|
|
||||||
|
def send_flush_message
|
||||||
|
message = []
|
||||||
|
message << "H".ord
|
||||||
|
message << [4].pack('l>').unpack('CCCC') # 4
|
||||||
|
message.flatten!
|
||||||
|
@socket.write(message.flatten.pack('C*'))
|
||||||
|
log "[F] Sent H message"
|
||||||
|
end
|
||||||
|
|
||||||
def read_from_server()
|
def read_from_server()
|
||||||
output_messages = []
|
output_messages = []
|
||||||
retry_count = 0
|
retry_count = 0
|
||||||
|
|||||||
@@ -16,14 +16,10 @@ describe "Portocol handling" do
|
|||||||
end
|
end
|
||||||
|
|
||||||
def run_comparison(sequence, socket_a, socket_b)
|
def run_comparison(sequence, socket_a, socket_b)
|
||||||
sequence.each do |msg|
|
sequence.each do |msg, *args|
|
||||||
if msg.is_a?(Symbol)
|
socket_a.send(msg, *args)
|
||||||
socket_a.send(msg)
|
socket_b.send(msg, *args)
|
||||||
socket_b.send(msg)
|
|
||||||
else
|
|
||||||
socket_a.send_message(msg)
|
|
||||||
socket_b.send_message(msg)
|
|
||||||
end
|
|
||||||
compare_messages(
|
compare_messages(
|
||||||
socket_a.read_from_server,
|
socket_a.read_from_server,
|
||||||
socket_b.read_from_server
|
socket_b.read_from_server
|
||||||
@@ -87,9 +83,9 @@ describe "Portocol handling" do
|
|||||||
|
|
||||||
context "Cancel Query" do
|
context "Cancel Query" do
|
||||||
let(:sequence) {
|
let(:sequence) {
|
||||||
[
|
[
|
||||||
SimpleQueryMessage.new("SELECT pg_sleep(5)"),
|
[:send_query_message, "SELECT pg_sleep(5)"],
|
||||||
:cancel_query
|
[:cancel_query]
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -99,12 +95,12 @@ describe "Portocol handling" do
|
|||||||
xcontext "Simple query after parse" do
|
xcontext "Simple query after parse" do
|
||||||
let(:sequence) {
|
let(:sequence) {
|
||||||
[
|
[
|
||||||
ParseMessage.new("", "SELECT 5", []),
|
[:send_parse_message, "SELECT 5"],
|
||||||
SimpleQueryMessage.new("SELECT 1"),
|
[:send_query_message, "SELECT 1"],
|
||||||
BindMessage.new("", "", [], [], [0]),
|
[:send_bind_message],
|
||||||
DescribeMessage.new("P", ""),
|
[:send_describe_message, "P"],
|
||||||
ExecuteMessage.new("", 1),
|
[:send_execute_message],
|
||||||
SyncMessage.new
|
[:send_sync_message],
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -115,8 +111,8 @@ describe "Portocol handling" do
|
|||||||
xcontext "Flush message" do
|
xcontext "Flush message" do
|
||||||
let(:sequence) {
|
let(:sequence) {
|
||||||
[
|
[
|
||||||
ParseMessage.new("", "SELECT 1", []),
|
[:send_parse_message, "SELECT 1"],
|
||||||
FlushMessage.new
|
[:send_flush_message]
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -126,7 +122,9 @@ describe "Portocol handling" do
|
|||||||
|
|
||||||
xcontext "Bind without parse" do
|
xcontext "Bind without parse" do
|
||||||
let(:sequence) {
|
let(:sequence) {
|
||||||
[BindMessage.new("", "", [], [], [0])]
|
[
|
||||||
|
[:send_bind_message]
|
||||||
|
]
|
||||||
}
|
}
|
||||||
# This is known to fail.
|
# This is known to fail.
|
||||||
# Server responds immediately, Proxy buffers the message
|
# Server responds immediately, Proxy buffers the message
|
||||||
@@ -135,155 +133,23 @@ describe "Portocol handling" do
|
|||||||
|
|
||||||
context "Simple message" do
|
context "Simple message" do
|
||||||
let(:sequence) {
|
let(:sequence) {
|
||||||
[SimpleQueryMessage.new("SELECT 1")]
|
[[:send_query_message, "SELECT 1"]]
|
||||||
}
|
}
|
||||||
|
|
||||||
it_behaves_like "at parity with database"
|
it_behaves_like "at parity with database"
|
||||||
end
|
end
|
||||||
|
|
||||||
10.times do |i|
|
|
||||||
context "Extended protocol" do
|
|
||||||
let(:sequence) {
|
|
||||||
[
|
|
||||||
ParseMessage.new("", "SELECT 1", []),
|
|
||||||
BindMessage.new("", "", [], [], [0]),
|
|
||||||
DescribeMessage.new("S", ""),
|
|
||||||
ExecuteMessage.new("", 1),
|
|
||||||
SyncMessage.new
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
it_behaves_like "at parity with database"
|
context "Extended protocol" do
|
||||||
end
|
let(:sequence) {
|
||||||
|
[
|
||||||
|
[:send_parse_message, "SELECT 1"],
|
||||||
|
[:send_bind_message],
|
||||||
|
[:send_describe_message, "P"],
|
||||||
|
[:send_execute_message],
|
||||||
|
[:send_sync_message],
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
it_behaves_like "at parity with database"
|
||||||
end
|
end
|
||||||
|
end
|
||||||
describe "Protocol-level prepared statements" do
|
|
||||||
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 1, "transaction") }
|
|
||||||
before do
|
|
||||||
q_sock = PostgresSocket.new('localhost', processes.pgcat.port)
|
|
||||||
q_sock.send_startup_message("sharding_user", "sharded_db", "sharding_user")
|
|
||||||
table_query = "CREATE TABLE IF NOT EXISTS employees (employee_id SERIAL PRIMARY KEY, salary NUMERIC(10, 2) CHECK (salary > 0));"
|
|
||||||
q_sock.send_message(SimpleQueryMessage.new(table_query))
|
|
||||||
q_sock.close
|
|
||||||
|
|
||||||
current_configs = processes.pgcat.current_config
|
|
||||||
current_configs["pools"]["sharded_db"]["prepared_statements_cache_size"] = 500
|
|
||||||
processes.pgcat.update_config(current_configs)
|
|
||||||
processes.pgcat.reload_config
|
|
||||||
end
|
|
||||||
after do
|
|
||||||
q_sock = PostgresSocket.new('localhost', processes.pgcat.port)
|
|
||||||
q_sock.send_startup_message("sharding_user", "sharded_db", "sharding_user")
|
|
||||||
table_query = "DROP TABLE IF EXISTS employees;"
|
|
||||||
q_sock.send_message(SimpleQueryMessage.new(table_query))
|
|
||||||
q_sock.close
|
|
||||||
end
|
|
||||||
|
|
||||||
context "When unnamed prepared statements are used" do
|
|
||||||
it "does not cache them" do
|
|
||||||
socket = PostgresSocket.new('localhost', processes.pgcat.port)
|
|
||||||
socket.send_startup_message("sharding_user", "sharded_db", "sharding_user")
|
|
||||||
|
|
||||||
socket.send_message(SimpleQueryMessage.new("DISCARD ALL"))
|
|
||||||
socket.read_from_server
|
|
||||||
|
|
||||||
10.times do |i|
|
|
||||||
socket.send_message(ParseMessage.new("", "SELECT #{i}", []))
|
|
||||||
socket.send_message(BindMessage.new("", "", [], [], [0]))
|
|
||||||
socket.send_message(DescribeMessage.new("S", ""))
|
|
||||||
socket.send_message(ExecuteMessage.new("", 1))
|
|
||||||
socket.send_message(SyncMessage.new)
|
|
||||||
socket.read_from_server
|
|
||||||
end
|
|
||||||
|
|
||||||
socket.send_message(SimpleQueryMessage.new("SELECT name, statement, prepare_time, parameter_types FROM pg_prepared_statements"))
|
|
||||||
result = socket.read_from_server
|
|
||||||
number_of_saved_statements = result.count { |m| m[:code] == 'D' }
|
|
||||||
expect(number_of_saved_statements).to eq(0)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
context "When named prepared statements are used" do
|
|
||||||
it "caches them" do
|
|
||||||
socket = PostgresSocket.new('localhost', processes.pgcat.port)
|
|
||||||
socket.send_startup_message("sharding_user", "sharded_db", "sharding_user")
|
|
||||||
|
|
||||||
socket.send_message(SimpleQueryMessage.new("DISCARD ALL"))
|
|
||||||
socket.read_from_server
|
|
||||||
|
|
||||||
3.times do
|
|
||||||
socket.send_message(ParseMessage.new("my_query", "SELECT * FROM employees WHERE employee_id in ($1,$2,$3)", [0,0,0]))
|
|
||||||
socket.send_message(BindMessage.new("", "my_query", [0,0,0], [0,0,0].map(&:to_s), [0,0,0,0,0,0]))
|
|
||||||
socket.send_message(SyncMessage.new)
|
|
||||||
socket.read_from_server
|
|
||||||
end
|
|
||||||
|
|
||||||
3.times do
|
|
||||||
socket.send_message(ParseMessage.new("my_other_query", "SELECT * FROM employees WHERE salary in ($1,$2,$3)", [0,0,0]))
|
|
||||||
socket.send_message(BindMessage.new("", "my_other_query", [0,0,0], [0,0,0].map(&:to_s), [0,0,0,0,0,0]))
|
|
||||||
socket.send_message(SyncMessage.new)
|
|
||||||
socket.read_from_server
|
|
||||||
end
|
|
||||||
|
|
||||||
socket.send_message(SimpleQueryMessage.new("SELECT name, statement, prepare_time, parameter_types FROM pg_prepared_statements"))
|
|
||||||
result = socket.read_from_server
|
|
||||||
number_of_saved_statements = result.count { |m| m[:code] == 'D' }
|
|
||||||
expect(number_of_saved_statements).to eq(2)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
context "When DISCARD ALL/DEALLOCATE ALL are called" do
|
|
||||||
it "resets server and client caches" do
|
|
||||||
socket = PostgresSocket.new('localhost', processes.pgcat.port)
|
|
||||||
socket.send_startup_message("sharding_user", "sharded_db", "sharding_user")
|
|
||||||
|
|
||||||
20.times do |i|
|
|
||||||
socket.send_message(ParseMessage.new("my_query_#{i}", "SELECT * FROM employees WHERE employee_id in ($1,$2,$3)", [0,0,0]))
|
|
||||||
end
|
|
||||||
|
|
||||||
20.times do |i|
|
|
||||||
socket.send_message(BindMessage.new("", "my_query_#{i}", [0,0,0], [0,0,0].map(&:to_s), [0,0]))
|
|
||||||
end
|
|
||||||
|
|
||||||
socket.send_message(SyncMessage.new)
|
|
||||||
socket.read_from_server
|
|
||||||
|
|
||||||
socket.send_message(SimpleQueryMessage.new("DISCARD ALL"))
|
|
||||||
socket.read_from_server
|
|
||||||
responses = []
|
|
||||||
4.times do |i|
|
|
||||||
socket.send_message(ParseMessage.new("my_query_#{i}", "SELECT * FROM employees WHERE employee_id in ($1,$2,$3)", [0,0,0]))
|
|
||||||
socket.send_message(BindMessage.new("", "my_query_#{i}", [0,0,0], [0,0,0].map(&:to_s), [0,0]))
|
|
||||||
socket.send_message(SyncMessage.new)
|
|
||||||
|
|
||||||
responses += socket.read_from_server
|
|
||||||
end
|
|
||||||
|
|
||||||
errors = responses.select { |message| message[:code] == 'E' }
|
|
||||||
error_message = errors.map { |message| message[:bytes].map(&:chr).join("") }.join("\n")
|
|
||||||
raise StandardError, "Encountered the following errors: #{error_message}" if errors.length > 0
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
context "Maximum number of bound paramters" do
|
|
||||||
it "does not crash" do
|
|
||||||
test_socket = PostgresSocket.new('localhost', processes.pgcat.port)
|
|
||||||
test_socket.send_startup_message("sharding_user", "sharded_db", "sharding_user")
|
|
||||||
|
|
||||||
types = Array.new(65_535) { |i| 0 }
|
|
||||||
|
|
||||||
params = Array.new(65_535) { |i| "$#{i+1}" }.join(",")
|
|
||||||
test_socket.send_message(ParseMessage.new("my_query", "SELECT * FROM employees WHERE employee_id in (#{params})", types))
|
|
||||||
|
|
||||||
test_socket.send_message(BindMessage.new("my_query", "my_query", types, types.map(&:to_s), types))
|
|
||||||
|
|
||||||
test_socket.send_message(SyncMessage.new)
|
|
||||||
|
|
||||||
# If the proxy crashes, this will raise an error
|
|
||||||
expect { test_socket.read_from_server }.to_not raise_error
|
|
||||||
|
|
||||||
test_socket.close
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|||||||
Reference in New Issue
Block a user