Compare commits

..

5 Commits

Author SHA1 Message Date
Mostafa
3fc9e5dec1 Merge branch 'main' of github.com:postgresml/pgcat into mostafa_fix_prepared_stmts 2024-09-03 18:11:32 -05:00
Mostafa
f7c5c0faf9 fix bind 2024-09-01 16:14:44 -05:00
Mostafa
982d03c374 fix syntax 2024-09-01 15:41:33 -05:00
Mostafa
686b7ca7c5 Fixes 2024-09-01 15:31:27 -05:00
Mostafa
7c55bf78fe Add failing tests 2024-09-01 14:39:05 -05:00
25 changed files with 662 additions and 766 deletions

View File

@@ -26,7 +26,6 @@ PGPASSWORD=sharding_user pgbench -h 127.0.0.1 -U sharding_user shard1 -i
PGPASSWORD=sharding_user pgbench -h 127.0.0.1 -U sharding_user shard2 -i
# Start Toxiproxy
kill -9 $(pgrep toxiproxy) || true
LOG_LEVEL=error toxiproxy-server &
sleep 1
@@ -107,7 +106,7 @@ cd ../..
# These tests will start and stop the pgcat server so it will need to be restarted after the tests
#
pip3 install -r tests/python/requirements.txt
pytest || exit 1
python3 tests/python/tests.py || exit 1
#
@@ -178,6 +177,3 @@ killall pgcat -s SIGINT
# Allow for graceful shutdown
sleep 1
kill -9 $(pgrep toxiproxy)
sleep 1

View File

@@ -22,7 +22,7 @@ jobs:
# Python is required because `ct lint` runs Yamale (https://github.com/23andMe/Yamale) and
# yamllint (https://github.com/adrienverge/yamllint) which require Python
- name: Set up Python
uses: actions/setup-python@v5.1.0
uses: actions/setup-python@v4.1.0
with:
python-version: 3.7
@@ -43,7 +43,7 @@ jobs:
run: ct lint --config ct.yaml
- name: Create kind cluster
uses: helm/kind-action@v1.10.0
uses: helm/kind-action@v1.7.0
if: steps.list-changed.outputs.changed == 'true'
- name: Run chart-testing (install)

View File

@@ -32,7 +32,7 @@ jobs:
version: v3.13.0
- name: Run chart-releaser
uses: helm/chart-releaser-action@a917fd15b20e8b64b94d9158ad54cd6345335584 # v1.6.0
uses: helm/chart-releaser-action@be16258da8010256c6e82849661221415f031968 # v1.5.0
with:
charts_dir: charts
config: cr.yaml

View File

@@ -1,9 +1,6 @@
name: pgcat package (deb)
on:
push:
tags:
- v*
workflow_dispatch:
inputs:
packageVersion:
@@ -19,14 +16,6 @@ jobs:
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v3
- name: Set package version
if: github.event_name == 'push' # For push event
run: |
TAG=${{ github.ref_name }}
echo "packageVersion=${TAG#v}" >> "$GITHUB_ENV"
- name: Set package version (manual dispatch)
if: github.event_name == 'workflow_dispatch' # For manual dispatch
run: echo "packageVersion=${{ github.event.inputs.packageVersion }}" >> "$GITHUB_ENV"
- uses: actions-rs/toolchain@v1
with:
toolchain: stable
@@ -50,10 +39,10 @@ jobs:
export ARCH=arm64
fi
bash utilities/deb.sh ${{ env.packageVersion }}
bash utilities/deb.sh ${{ inputs.packageVersion }}
deb-s3 upload \
--lock \
--bucket apt.postgresml.org \
pgcat-${{ env.packageVersion }}-ubuntu22.04-${ARCH}.deb \
pgcat-${{ inputs.packageVersion }}-ubuntu22.04-${ARCH}.deb \
--codename $(lsb_release -cs)

3
.gitignore vendored
View File

@@ -10,5 +10,4 @@ lcov.info
dev/.bash_history
dev/cache
!dev/cache/.keepme
.venv
**/__pycache__
.venv

View File

@@ -7,7 +7,7 @@ Thank you for contributing! Just a few tips here:
3. Performance is important, make sure there are no regressions in your branch vs. `main`.
## How to run the integration tests locally and iterate on them
We have integration tests written in Ruby, Python, Go and Rust.
We have integration tests written in Ruby, Python, Go and Rust.
Below are the steps to run them in a developer-friendly way that allows iterating and quick turnaround.
Hear me out, this should be easy, it will involve opening a shell into a container with all the necessary dependancies available for you and you can modify the test code and immediately rerun your test in the interactive shell.
@@ -21,7 +21,7 @@ Within this test environment you can modify the file in your favorite IDE and re
Once the environment is ready, you can run the tests by running
Ruby: `cd /app/tests/ruby && bundle exec ruby <test_name>.rb --format documentation`
Python: `cd /app/ && pytest`
Python: `cd /app && python3 tests/python/tests.py`
Rust: `cd /app/tests/rust && cargo run`
Go: `cd /app/tests/go && /usr/local/go/bin/go test`

View File

@@ -5,4 +5,4 @@ maintainers:
- name: Wildcard
email: support@w6d.io
appVersion: "1.2.0"
version: 0.2.1
version: 0.2.0

View File

@@ -15,7 +15,6 @@ stringData:
connect_timeout = {{ .Values.configuration.general.connect_timeout }}
idle_timeout = {{ .Values.configuration.general.idle_timeout | int }}
server_lifetime = {{ .Values.configuration.general.server_lifetime | int }}
server_tls = {{ .Values.configuration.general.server_tls }}
idle_client_in_transaction_timeout = {{ .Values.configuration.general.idle_client_in_transaction_timeout | int }}
healthcheck_timeout = {{ .Values.configuration.general.healthcheck_timeout }}
healthcheck_delay = {{ .Values.configuration.general.healthcheck_delay }}
@@ -59,21 +58,11 @@ stringData:
##
[pools.{{ $pool.name | quote }}.users.{{ $index }}]
username = {{ $user.username | quote }}
{{- if $user.password }}
password = {{ $user.password | quote }}
{{- else if and $user.passwordSecret.name $user.passwordSecret.key }}
{{- $secret := (lookup "v1" "Secret" $.Release.Namespace $user.passwordSecret.name) }}
{{- if $secret }}
{{- $password := index $secret.data $user.passwordSecret.key | b64dec }}
password = {{ $password | quote }}
{{- end }}
{{- end }}
pool_size = {{ $user.pool_size }}
statement_timeout = {{ default 0 $user.statement_timeout }}
min_pool_size = {{ default 3 $user.min_pool_size }}
{{- if $user.server_lifetime }}
server_lifetime = {{ $user.server_lifetime }}
{{- end }}
statement_timeout = {{ $user.statement_timeout }}
min_pool_size = 3
server_lifetime = 60000
{{- if and $user.server_username $user.server_password }}
server_username = {{ $user.server_username | quote }}
server_password = {{ $user.server_password | quote }}

View File

@@ -175,9 +175,6 @@ configuration:
# Max connection lifetime before it's closed, even if actively used.
server_lifetime: 86400000 # 24 hours
# Whether to use TLS for server connections or not.
server_tls: false
# How long a client is allowed to be idle while in a transaction (ms).
idle_client_in_transaction_timeout: 0 # milliseconds
@@ -318,9 +315,7 @@ configuration:
# ## Credentials for users that may connect to this cluster
# ## @param users [array]
# ## @param users[0].username Name of the env var (required)
# ## @param users[0].password Value for the env var (required) leave empty to use existing secret see passwordSecret.name and passwordSecret.key
# ## @param users[0].passwordSecret.name Name of the secret containing the password
# ## @param users[0].passwordSecret.key Key in the secret containing the password
# ## @param users[0].password Value for the env var (required)
# ## @param users[0].pool_size Maximum number of server connections that can be established for this user
# ## @param users[0].statement_timeout Maximum query duration. Dangerous, but protects against DBs that died in a non-obvious way.
# users: []

View File

@@ -7,7 +7,3 @@ systemctl enable pgcat
if ! id pgcat 2> /dev/null; then
useradd -s /usr/bin/false pgcat
fi
if [ -f /etc/pgcat.toml ]; then
systemctl start pgcat
fi

View File

@@ -1,4 +1,3 @@
use crate::config::AuthType;
use crate::errors::Error;
use crate::pool::ConnectionPool;
use crate::server::Server;
@@ -72,7 +71,6 @@ impl AuthPassthrough {
pub async fn fetch_hash(&self, address: &crate::config::Address) -> Result<String, Error> {
let auth_user = crate::config::User {
username: self.user.clone(),
auth_type: AuthType::MD5,
password: Some(self.password.clone()),
server_username: None,
server_password: None,

View File

@@ -14,9 +14,7 @@ use tokio::sync::mpsc::Sender;
use crate::admin::{generate_server_parameters_for_admin, handle_admin};
use crate::auth_passthrough::refetch_auth_hash;
use crate::config::{
get_config, get_idle_client_in_transaction_timeout, Address, AuthType, PoolMode,
};
use crate::config::{get_config, get_idle_client_in_transaction_timeout, Address, PoolMode};
use crate::constants::*;
use crate::messages::*;
use crate::plugins::PluginOutput;
@@ -465,8 +463,8 @@ where
.count()
== 1;
// Kick any client that's not admin while we're in admin-only mode.
if !admin && admin_only {
// Kick any client that's not admin while we're in admin-only mode.
debug!(
"Rejecting non-admin connection to {} when in admin only mode",
pool_name
@@ -483,76 +481,72 @@ where
let process_id: i32 = rand::random();
let secret_key: i32 = rand::random();
// Perform MD5 authentication.
// TODO: Add SASL support.
let salt = md5_challenge(&mut write).await?;
let code = match read.read_u8().await {
Ok(p) => p,
Err(_) => {
return Err(Error::ClientSocketError(
"password code".into(),
client_identifier,
))
}
};
// PasswordMessage
if code as char != 'p' {
return Err(Error::ProtocolSyncError(format!(
"Expected p, got {}",
code as char
)));
}
let len = match read.read_i32().await {
Ok(len) => len,
Err(_) => {
return Err(Error::ClientSocketError(
"password message length".into(),
client_identifier,
))
}
};
let mut password_response = vec![0u8; (len - 4) as usize];
match read.read_exact(&mut password_response).await {
Ok(_) => (),
Err(_) => {
return Err(Error::ClientSocketError(
"password message".into(),
client_identifier,
))
}
};
let mut prepared_statements_enabled = false;
// Authenticate admin user.
let (transaction_mode, mut server_parameters) = if admin {
let config = get_config();
// TODO: Add SASL support.
// Perform MD5 authentication.
match config.general.admin_auth_type {
AuthType::Trust => (),
AuthType::MD5 => {
let salt = md5_challenge(&mut write).await?;
let code = match read.read_u8().await {
Ok(p) => p,
Err(_) => {
return Err(Error::ClientSocketError(
"password code".into(),
client_identifier,
))
}
};
// Compare server and client hashes.
let password_hash = md5_hash_password(
&config.general.admin_username,
&config.general.admin_password,
&salt,
);
// PasswordMessage
if code as char != 'p' {
return Err(Error::ProtocolSyncError(format!(
"Expected p, got {}",
code as char
)));
}
if password_hash != password_response {
let error = Error::ClientGeneralError("Invalid password".into(), client_identifier);
let len = match read.read_i32().await {
Ok(len) => len,
Err(_) => {
return Err(Error::ClientSocketError(
"password message length".into(),
client_identifier,
))
}
};
warn!("{}", error);
wrong_password(&mut write, username).await?;
let mut password_response = vec![0u8; (len - 4) as usize];
match read.read_exact(&mut password_response).await {
Ok(_) => (),
Err(_) => {
return Err(Error::ClientSocketError(
"password message".into(),
client_identifier,
))
}
};
// Compare server and client hashes.
let password_hash = md5_hash_password(
&config.general.admin_username,
&config.general.admin_password,
&salt,
);
if password_hash != password_response {
let error =
Error::ClientGeneralError("Invalid password".into(), client_identifier);
warn!("{}", error);
wrong_password(&mut write, username).await?;
return Err(error);
}
}
return Err(error);
}
(false, generate_server_parameters_for_admin())
}
// Authenticate normal user.
@@ -579,143 +573,92 @@ where
// Obtain the hash to compare, we give preference to that written in cleartext in config
// if there is nothing set in cleartext and auth passthrough (auth_query) is configured, we use the hash obtained
// when the pool was created. If there is no hash there, we try to fetch it one more time.
match pool.settings.user.auth_type {
AuthType::Trust => (),
AuthType::MD5 => {
// Perform MD5 authentication.
// TODO: Add SASL support.
let salt = md5_challenge(&mut write).await?;
let password_hash = if let Some(password) = &pool.settings.user.password {
Some(md5_hash_password(username, password, &salt))
} else {
if !get_config().is_auth_query_configured() {
wrong_password(&mut write, username).await?;
return Err(Error::ClientAuthImpossible(username.into()));
}
let code = match read.read_u8().await {
Ok(p) => p,
Err(_) => {
return Err(Error::ClientSocketError(
"password code".into(),
client_identifier,
))
}
};
let mut hash = (*pool.auth_hash.read()).clone();
// PasswordMessage
if code as char != 'p' {
return Err(Error::ProtocolSyncError(format!(
"Expected p, got {}",
code as char
)));
}
if hash.is_none() {
warn!(
"Query auth configured \
but no hash password found \
for pool {}. Will try to refetch it.",
pool_name
);
let len = match read.read_i32().await {
Ok(len) => len,
Err(_) => {
return Err(Error::ClientSocketError(
"password message length".into(),
client_identifier,
))
}
};
let mut password_response = vec![0u8; (len - 4) as usize];
match read.read_exact(&mut password_response).await {
Ok(_) => (),
Err(_) => {
return Err(Error::ClientSocketError(
"password message".into(),
client_identifier,
))
}
};
let password_hash = if let Some(password) = &pool.settings.user.password {
Some(md5_hash_password(username, password, &salt))
} else {
if !get_config().is_auth_query_configured() {
wrong_password(&mut write, username).await?;
return Err(Error::ClientAuthImpossible(username.into()));
}
let mut hash = (*pool.auth_hash.read()).clone();
if hash.is_none() {
warn!(
"Query auth configured \
but no hash password found \
for pool {}. Will try to refetch it.",
pool_name
);
match refetch_auth_hash(&pool).await {
Ok(fetched_hash) => {
warn!(
"Password for {}, obtained. Updating.",
client_identifier
);
{
let mut pool_auth_hash = pool.auth_hash.write();
*pool_auth_hash = Some(fetched_hash.clone());
}
hash = Some(fetched_hash);
}
Err(err) => {
wrong_password(&mut write, username).await?;
return Err(Error::ClientAuthPassthroughError(
err.to_string(),
client_identifier,
));
}
}
};
Some(md5_hash_second_pass(&hash.unwrap(), &salt))
};
// Once we have the resulting hash, we compare with what the client gave us.
// If they do not match and auth query is set up, we try to refetch the hash one more time
// to see if the password has changed since the pool was created.
//
// @TODO: we could end up fetching again the same password twice (see above).
if password_hash.unwrap() != password_response {
warn!(
"Invalid password {}, will try to refetch it.",
client_identifier
);
let fetched_hash = match refetch_auth_hash(&pool).await {
Ok(fetched_hash) => fetched_hash,
Err(err) => {
wrong_password(&mut write, username).await?;
return Err(err);
}
};
let new_password_hash = md5_hash_second_pass(&fetched_hash, &salt);
// Ok password changed in server an auth is possible.
if new_password_hash == password_response {
warn!(
"Password for {}, changed in server. Updating.",
client_identifier
);
match refetch_auth_hash(&pool).await {
Ok(fetched_hash) => {
warn!("Password for {}, obtained. Updating.", client_identifier);
{
let mut pool_auth_hash = pool.auth_hash.write();
*pool_auth_hash = Some(fetched_hash);
*pool_auth_hash = Some(fetched_hash.clone());
}
} else {
hash = Some(fetched_hash);
}
Err(err) => {
wrong_password(&mut write, username).await?;
return Err(Error::ClientGeneralError(
"Invalid password".into(),
return Err(Error::ClientAuthPassthroughError(
err.to_string(),
client_identifier,
));
}
}
};
Some(md5_hash_second_pass(&hash.unwrap(), &salt))
};
// Once we have the resulting hash, we compare with what the client gave us.
// If they do not match and auth query is set up, we try to refetch the hash one more time
// to see if the password has changed since the pool was created.
//
// @TODO: we could end up fetching again the same password twice (see above).
if password_hash.unwrap() != password_response {
warn!(
"Invalid password {}, will try to refetch it.",
client_identifier
);
let fetched_hash = match refetch_auth_hash(&pool).await {
Ok(fetched_hash) => fetched_hash,
Err(err) => {
wrong_password(&mut write, username).await?;
return Err(err);
}
};
let new_password_hash = md5_hash_second_pass(&fetched_hash, &salt);
// Ok password changed in server an auth is possible.
if new_password_hash == password_response {
warn!(
"Password for {}, changed in server. Updating.",
client_identifier
);
{
let mut pool_auth_hash = pool.auth_hash.write();
*pool_auth_hash = Some(fetched_hash);
}
} else {
wrong_password(&mut write, username).await?;
return Err(Error::ClientGeneralError(
"Invalid password".into(),
client_identifier,
));
}
}
let transaction_mode = pool.settings.pool_mode == PoolMode::Transaction;
prepared_statements_enabled =
transaction_mode && pool.prepared_statement_cache.is_some();
@@ -1786,14 +1729,13 @@ where
/// and also the pool's statement cache. Add it to extended protocol data.
fn buffer_parse(&mut self, message: BytesMut, pool: &ConnectionPool) -> Result<(), Error> {
// Avoid parsing if prepared statements not enabled
if !self.prepared_statements_enabled {
let client_given_name = Parse::get_name(&message)?;
if !self.prepared_statements_enabled || client_given_name.is_empty() {
debug!("Anonymous parse message");
self.extended_protocol_data_buffer
.push_back(ExtendedProtocolData::create_new_parse(message, None));
return Ok(());
}
let client_given_name = Parse::get_name(&message)?;
let parse: Parse = (&message).try_into()?;
// Compute the hash of the parse statement
@@ -1831,15 +1773,14 @@ where
/// saved in the client cache.
async fn buffer_bind(&mut self, message: BytesMut) -> Result<(), Error> {
// Avoid parsing if prepared statements not enabled
if !self.prepared_statements_enabled {
let client_given_name = Bind::get_name(&message)?;
if !self.prepared_statements_enabled || client_given_name.is_empty() {
debug!("Anonymous bind message");
self.extended_protocol_data_buffer
.push_back(ExtendedProtocolData::create_new_bind(message, None));
return Ok(());
}
let client_given_name = Bind::get_name(&message)?;
match self.prepared_statements.get(&client_given_name) {
Some((rewritten_parse, _)) => {
let message = Bind::rename(message, &rewritten_parse.name)?;
@@ -1891,7 +1832,8 @@ where
}
let describe: Describe = (&message).try_into()?;
if describe.target == 'P' {
let client_given_name = describe.statement_name.clone();
if describe.target == 'P' || client_given_name.is_empty() {
debug!("Portal describe message");
self.extended_protocol_data_buffer
.push_back(ExtendedProtocolData::create_new_describe(message, None));
@@ -1899,8 +1841,6 @@ where
return Ok(());
}
let client_given_name = describe.statement_name.clone();
match self.prepared_statements.get(&client_given_name) {
Some((rewritten_parse, _)) => {
let describe = describe.rename(&rewritten_parse.name);

View File

@@ -208,9 +208,6 @@ impl Address {
pub struct User {
pub username: String,
pub password: Option<String>,
#[serde(default = "User::default_auth_type")]
pub auth_type: AuthType,
pub server_username: Option<String>,
pub server_password: Option<String>,
pub pool_size: u32,
@@ -228,7 +225,6 @@ impl Default for User {
User {
username: String::from("postgres"),
password: None,
auth_type: AuthType::MD5,
server_username: None,
server_password: None,
pool_size: 15,
@@ -243,10 +239,6 @@ impl Default for User {
}
impl User {
pub fn default_auth_type() -> AuthType {
AuthType::MD5
}
fn validate(&self) -> Result<(), Error> {
if let Some(min_pool_size) = self.min_pool_size {
if min_pool_size > self.pool_size {
@@ -342,9 +334,6 @@ pub struct General {
pub admin_username: String,
pub admin_password: String,
#[serde(default = "General::default_admin_auth_type")]
pub admin_auth_type: AuthType,
#[serde(default = "General::default_validate_config")]
pub validate_config: bool,
@@ -359,10 +348,6 @@ impl General {
"0.0.0.0".into()
}
pub fn default_admin_auth_type() -> AuthType {
AuthType::MD5
}
pub fn default_port() -> u16 {
5432
}
@@ -471,7 +456,6 @@ impl Default for General {
verify_server_certificate: false,
admin_username: String::from("admin"),
admin_password: String::from("admin"),
admin_auth_type: AuthType::MD5,
validate_config: true,
auth_query: None,
auth_query_user: None,
@@ -492,15 +476,6 @@ pub enum PoolMode {
Session,
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Copy, Hash)]
pub enum AuthType {
#[serde(alias = "trust", alias = "Trust")]
Trust,
#[serde(alias = "md5", alias = "MD5")]
MD5,
}
impl std::fmt::Display for PoolMode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {

View File

@@ -821,10 +821,10 @@ impl ExtendedProtocolData {
pub struct Parse {
code: char,
#[allow(dead_code)]
len: i32,
len: u32,
pub name: String,
query: String,
num_params: i16,
num_params: u16,
param_types: Vec<i32>,
}
@@ -834,12 +834,11 @@ impl TryFrom<&BytesMut> for Parse {
fn try_from(buf: &BytesMut) -> Result<Parse, Error> {
let mut cursor = Cursor::new(buf);
let code = cursor.get_u8() as char;
let len = cursor.get_i32();
let len = cursor.get_u32();
let name = cursor.read_string()?;
let query = cursor.read_string()?;
let num_params = cursor.get_i16();
let num_params = cursor.get_u16();
let mut param_types = Vec::new();
for _ in 0..num_params {
param_types.push(cursor.get_i32());
}
@@ -875,10 +874,10 @@ impl TryFrom<Parse> for BytesMut {
+ 4 * parse.num_params as usize;
bytes.put_u8(parse.code as u8);
bytes.put_i32(len as i32);
bytes.put_u32(len as u32);
bytes.put_slice(name);
bytes.put_slice(query);
bytes.put_i16(parse.num_params);
bytes.put_u16(parse.num_params);
for param in parse.param_types {
bytes.put_i32(param);
}
@@ -945,14 +944,14 @@ impl Parse {
pub struct Bind {
code: char,
#[allow(dead_code)]
len: i64,
len: u64,
portal: String,
pub prepared_statement: String,
num_param_format_codes: i16,
num_param_format_codes: u16,
param_format_codes: Vec<i16>,
num_param_values: i16,
num_param_values: u16,
param_values: Vec<(i32, BytesMut)>,
num_result_column_format_codes: i16,
num_result_column_format_codes: u16,
result_columns_format_codes: Vec<i16>,
}
@@ -962,17 +961,17 @@ impl TryFrom<&BytesMut> for Bind {
fn try_from(buf: &BytesMut) -> Result<Bind, Error> {
let mut cursor = Cursor::new(buf);
let code = cursor.get_u8() as char;
let len = cursor.get_i32();
let len = cursor.get_u32();
let portal = cursor.read_string()?;
let prepared_statement = cursor.read_string()?;
let num_param_format_codes = cursor.get_i16();
let num_param_format_codes = cursor.get_u16();
let mut param_format_codes = Vec::new();
for _ in 0..num_param_format_codes {
param_format_codes.push(cursor.get_i16());
}
let num_param_values = cursor.get_i16();
let num_param_values = cursor.get_u16();
let mut param_values = Vec::new();
for _ in 0..num_param_values {
@@ -994,7 +993,7 @@ impl TryFrom<&BytesMut> for Bind {
}
}
let num_result_column_format_codes = cursor.get_i16();
let num_result_column_format_codes = cursor.get_u16();
let mut result_columns_format_codes = Vec::new();
for _ in 0..num_result_column_format_codes {
@@ -1003,7 +1002,7 @@ impl TryFrom<&BytesMut> for Bind {
Ok(Bind {
code,
len: len as i64,
len: len as u64,
portal,
prepared_statement,
num_param_format_codes,
@@ -1042,19 +1041,19 @@ impl TryFrom<Bind> for BytesMut {
len += 2 * bind.num_result_column_format_codes as usize;
bytes.put_u8(bind.code as u8);
bytes.put_i32(len as i32);
bytes.put_u32(len as u32);
bytes.put_slice(portal);
bytes.put_slice(prepared_statement);
bytes.put_i16(bind.num_param_format_codes);
bytes.put_u16(bind.num_param_format_codes);
for param_format_code in bind.param_format_codes {
bytes.put_i16(param_format_code);
}
bytes.put_i16(bind.num_param_values);
bytes.put_u16(bind.num_param_values);
for (param_len, param) in bind.param_values {
bytes.put_i32(param_len);
bytes.put_slice(&param);
}
bytes.put_i16(bind.num_result_column_format_codes);
bytes.put_u16(bind.num_result_column_format_codes);
for result_column_format_code in bind.result_columns_format_codes {
bytes.put_i16(result_column_format_code);
}
@@ -1068,7 +1067,7 @@ impl Bind {
pub fn get_name(buf: &BytesMut) -> Result<String, Error> {
let mut cursor = Cursor::new(buf);
// Skip the code and length
cursor.advance(mem::size_of::<u8>() + mem::size_of::<i32>());
cursor.advance(mem::size_of::<u8>() + mem::size_of::<u32>());
cursor.read_string()?;
cursor.read_string()
}
@@ -1078,17 +1077,17 @@ impl Bind {
let mut cursor = Cursor::new(&buf);
// Read basic data from the cursor
let code = cursor.get_u8();
let current_len = cursor.get_i32();
let current_len = cursor.get_u32();
let portal = cursor.read_string()?;
let prepared_statement = cursor.read_string()?;
// Calculate new length
let new_len = current_len + new_name.len() as i32 - prepared_statement.len() as i32;
let new_len = current_len + new_name.len() as u32 - prepared_statement.len() as u32;
// Begin building the response buffer
let mut response_buf = BytesMut::with_capacity(new_len as usize + 1);
response_buf.put_u8(code);
response_buf.put_i32(new_len);
response_buf.put_u32(new_len);
// Put the portal and new name into the buffer
// Note: panic if the provided string contains null byte
@@ -1112,7 +1111,7 @@ pub struct Describe {
code: char,
#[allow(dead_code)]
len: i32,
len: u32,
pub target: char,
pub statement_name: String,
}
@@ -1123,7 +1122,7 @@ impl TryFrom<&BytesMut> for Describe {
fn try_from(bytes: &BytesMut) -> Result<Describe, Error> {
let mut cursor = Cursor::new(bytes);
let code = cursor.get_u8() as char;
let len = cursor.get_i32();
let len = cursor.get_u32();
let target = cursor.get_u8() as char;
let statement_name = cursor.read_string()?;
@@ -1146,7 +1145,7 @@ impl TryFrom<Describe> for BytesMut {
let len = 4 + 1 + statement_name.len();
bytes.put_u8(describe.code as u8);
bytes.put_i32(len as i32);
bytes.put_u32(len as u32);
bytes.put_u8(describe.target as u8);
bytes.put_slice(statement_name);

View File

@@ -200,17 +200,18 @@ struct PrometheusMetric<Value: fmt::Display> {
impl<Value: fmt::Display> fmt::Display for PrometheusMetric<Value> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let mut sorted_labels: Vec<_> = self.labels.iter().collect();
sorted_labels.sort_by_key(|&(key, _)| key);
let formatted_labels = sorted_labels
let formatted_labels = self
.labels
.iter()
.map(|(key, value)| format!("{}=\"{}\"", key, value))
.collect::<Vec<_>>()
.join(",");
write!(
f,
"{name}{{{formatted_labels}}} {value}",
"# HELP {name} {help}\n# TYPE {name} {ty}\n{name}{{{formatted_labels}}} {value}\n",
name = format_args!("pgcat_{}", self.name),
help = self.help,
ty = self.ty,
formatted_labels = formatted_labels,
value = self.value
)
@@ -246,7 +247,7 @@ impl<Value: fmt::Display> PrometheusMetric<Value> {
labels.insert("pool", address.pool_name.clone());
labels.insert("index", address.address_index.to_string());
labels.insert("database", address.database.to_string());
labels.insert("username", address.username.clone());
labels.insert("user", address.username.clone());
Self::from_name(&format!("databases_{}", name), value, labels)
}
@@ -263,8 +264,7 @@ impl<Value: fmt::Display> PrometheusMetric<Value> {
labels.insert("pool", address.pool_name.clone());
labels.insert("index", address.address_index.to_string());
labels.insert("database", address.database.to_string());
labels.insert("username", address.username.clone());
labels.insert("user", address.username.clone());
Self::from_name(&format!("servers_{}", name), value, labels)
}
@@ -276,7 +276,7 @@ impl<Value: fmt::Display> PrometheusMetric<Value> {
labels.insert("role", address.role.to_string());
labels.insert("index", address.address_index.to_string());
labels.insert("database", address.database.to_string());
labels.insert("username", address.username.clone());
labels.insert("user", address.username.clone());
Self::from_name(&format!("stats_{}", name), value, labels)
}
@@ -288,15 +288,6 @@ impl<Value: fmt::Display> PrometheusMetric<Value> {
Self::from_name(&format!("pools_{}", name), value, labels)
}
fn get_header(&self) -> String {
format!(
"\n# HELP {name} {help}\n# TYPE {name} {ty}",
name = format_args!("pgcat_{}", self.name),
help = self.help,
ty = self.ty,
)
}
}
async fn prometheus_stats(
@@ -322,7 +313,6 @@ async fn prometheus_stats(
// Adds metrics shown in a SHOW STATS admin command.
fn push_address_stats(lines: &mut Vec<String>) {
let mut grouped_metrics: HashMap<String, Vec<PrometheusMetric<u64>>> = HashMap::new();
for (_, pool) in get_all_pools() {
for shard in 0..pool.shards() {
for server in 0..pool.servers(shard) {
@@ -332,10 +322,7 @@ fn push_address_stats(lines: &mut Vec<String>) {
if let Some(prometheus_metric) =
PrometheusMetric::<u64>::from_address(address, &key, value)
{
grouped_metrics
.entry(key)
.or_default()
.push(prometheus_metric);
lines.push(prometheus_metric.to_string());
} else {
debug!("Metric {} not implemented for {}", key, address.name());
}
@@ -343,53 +330,33 @@ fn push_address_stats(lines: &mut Vec<String>) {
}
}
}
for (_key, metrics) in grouped_metrics {
if !metrics.is_empty() {
lines.push(metrics[0].get_header());
for metric in metrics {
lines.push(metric.to_string());
}
}
}
}
// Adds relevant metrics shown in a SHOW POOLS admin command.
fn push_pool_stats(lines: &mut Vec<String>) {
let mut grouped_metrics: HashMap<String, Vec<PrometheusMetric<u64>>> = HashMap::new();
let pool_stats = PoolStats::construct_pool_lookup();
for (pool_id, stats) in pool_stats.iter() {
for (name, value) in stats.clone() {
if let Some(prometheus_metric) =
PrometheusMetric::<u64>::from_pool(pool_id.clone(), &name, value)
{
grouped_metrics
.entry(name)
.or_default()
.push(prometheus_metric);
lines.push(prometheus_metric.to_string());
} else {
debug!("Metric {} not implemented for ({})", name, *pool_id);
}
}
}
for (_key, metrics) in grouped_metrics {
if !metrics.is_empty() {
lines.push(metrics[0].get_header());
for metric in metrics {
lines.push(metric.to_string());
}
}
}
}
// Adds relevant metrics shown in a SHOW DATABASES admin command.
fn push_database_stats(lines: &mut Vec<String>) {
let mut grouped_metrics: HashMap<String, Vec<PrometheusMetric<u32>>> = HashMap::new();
for (_, pool) in get_all_pools() {
let pool_config = pool.settings.clone();
for shard in 0..pool.shards() {
for server in 0..pool.servers(shard) {
let address = pool.address(shard, server);
let pool_state = pool.pool_state(shard, server);
let metrics = vec![
("pool_size", pool_config.user.pool_size),
("current_connections", pool_state.connections),
@@ -398,10 +365,7 @@ fn push_database_stats(lines: &mut Vec<String>) {
if let Some(prometheus_metric) =
PrometheusMetric::<u32>::from_database_info(address, key, value)
{
grouped_metrics
.entry(key.to_string())
.or_default()
.push(prometheus_metric);
lines.push(prometheus_metric.to_string());
} else {
debug!("Metric {} not implemented for {}", key, address.name());
}
@@ -409,14 +373,6 @@ fn push_database_stats(lines: &mut Vec<String>) {
}
}
}
for (_key, metrics) in grouped_metrics {
if !metrics.is_empty() {
lines.push(metrics[0].get_header());
for metric in metrics {
lines.push(metric.to_string());
}
}
}
}
// Adds relevant metrics shown in a SHOW SERVERS admin command.
@@ -449,7 +405,7 @@ fn push_server_stats(lines: &mut Vec<String>) {
crate::stats::ServerState::Idle => entry.idle_count += 1,
}
}
let mut grouped_metrics: HashMap<String, Vec<PrometheusMetric<u64>>> = HashMap::new();
for (_, pool) in get_all_pools() {
for shard in 0..pool.shards() {
for server in 0..pool.servers(shard) {
@@ -472,10 +428,7 @@ fn push_server_stats(lines: &mut Vec<String>) {
if let Some(prometheus_metric) =
PrometheusMetric::<u64>::from_server_info(address, key, value)
{
grouped_metrics
.entry(key.to_string())
.or_default()
.push(prometheus_metric);
lines.push(prometheus_metric.to_string());
} else {
debug!("Metric {} not implemented for {}", key, address.name());
}
@@ -484,14 +437,6 @@ fn push_server_stats(lines: &mut Vec<String>) {
}
}
}
for (_key, metrics) in grouped_metrics {
if !metrics.is_empty() {
lines.push(metrics[0].get_header());
for metric in metrics {
lines.push(metric.to_string());
}
}
}
}
pub async fn start_metric_server(http_addr: SocketAddr) {

View File

@@ -698,7 +698,6 @@ impl Server {
))
}
};
trace!("Error: {}", error_code);
match error_code {
@@ -1013,6 +1012,12 @@ impl Server {
// which can leak between clients. This is a best effort to block bad clients
// from poisoning a transaction-mode pool by setting inappropriate session variables
match command.as_str() {
"DISCARD ALL" => {
self.clear_prepared_statement_cache();
}
"DEALLOCATE ALL" => {
self.clear_prepared_statement_cache();
}
"SET" => {
// We don't detect set statements in transactions
// No great way to differentiate between set and set local
@@ -1132,6 +1137,12 @@ impl Server {
has_it
}
fn clear_prepared_statement_cache(&mut self) {
if let Some(cache) = &mut self.prepared_statement_cache {
cache.clear();
}
}
fn add_prepared_statement_to_cache(&mut self, name: &str) -> Option<String> {
let cache = match &mut self.prepared_statement_cache {
Some(cache) => cache,

View File

@@ -23,11 +23,11 @@ docker compose exec --workdir /app/tests/python main pip3 install -r requirement
echo "Interactive test environment ready"
echo "To run integration tests, you can use the following commands:"
echo -e " ${BLUE}Ruby: ${RED}cd /app/tests/ruby && bundle exec ruby tests.rb --format documentation${RESET}"
echo -e " ${BLUE}Python: ${RED}cd /app/ && pytest ${RESET}"
echo -e " ${BLUE}Python: ${RED}cd /app && python3 tests/python/tests.py${RESET}"
echo -e " ${BLUE}Rust: ${RED}cd /app/tests/rust && cargo run ${RESET}"
echo -e " ${BLUE}Go: ${RED}cd /app/tests/go && /usr/local/go/bin/go test${RESET}"
echo "the source code for tests are directly linked to the source code in the container so you can modify the code and run the tests again"
echo "You can rebuild PgCat from within the container by running"
echo "You can rebuild PgCat from within the container by running"
echo -e " ${GREEN}cargo build${RESET}"
echo "and then run the tests again"
echo "==================================="

View File

@@ -1,3 +1,2 @@
pytest
psycopg2==2.9.3
psutil==5.9.1
psutil==5.9.1

View File

@@ -1,71 +0,0 @@
import utils
import signal
class TestTrustAuth:
@classmethod
def setup_method(cls):
config= """
[general]
host = "0.0.0.0"
port = 6432
admin_username = "admin_user"
admin_password = ""
admin_auth_type = "trust"
[pools.sharded_db.users.0]
username = "sharding_user"
password = "sharding_user"
auth_type = "trust"
pool_size = 10
min_pool_size = 1
pool_mode = "transaction"
[pools.sharded_db.shards.0]
servers = [
[ "127.0.0.1", 5432, "primary" ],
]
database = "shard0"
"""
utils.pgcat_generic_start(config)
@classmethod
def teardown_method(self):
utils.pg_cat_send_signal(signal.SIGTERM)
def test_admin_trust_auth(self):
conn, cur = utils.connect_db_trust(admin=True)
cur.execute("SHOW POOLS")
res = cur.fetchall()
print(res)
utils.cleanup_conn(conn, cur)
def test_normal_trust_auth(self):
conn, cur = utils.connect_db_trust(autocommit=False)
cur.execute("SELECT 1")
res = cur.fetchall()
print(res)
utils.cleanup_conn(conn, cur)
class TestMD5Auth:
@classmethod
def setup_method(cls):
utils.pgcat_start()
@classmethod
def teardown_method(self):
utils.pg_cat_send_signal(signal.SIGTERM)
def test_normal_db_access(self):
conn, cur = utils.connect_db(autocommit=False)
cur.execute("SELECT 1")
res = cur.fetchall()
print(res)
utils.cleanup_conn(conn, cur)
def test_admin_db_access(self):
conn, cur = utils.connect_db(admin=True)
cur.execute("SHOW POOLS")
res = cur.fetchall()
print(res)
utils.cleanup_conn(conn, cur)

View File

@@ -1,12 +1,84 @@
from typing import Tuple
import psycopg2
import psutil
import os
import signal
import time
import psycopg2
import utils
SHUTDOWN_TIMEOUT = 5
PGCAT_HOST = "127.0.0.1"
PGCAT_PORT = "6432"
def pgcat_start():
pg_cat_send_signal(signal.SIGTERM)
os.system("./target/debug/pgcat .circleci/pgcat.toml &")
time.sleep(2)
def pg_cat_send_signal(signal: signal.Signals):
try:
for proc in psutil.process_iter(["pid", "name"]):
if "pgcat" == proc.name():
os.kill(proc.pid, signal)
except Exception as e:
# The process can be gone when we send this signal
print(e)
if signal == signal.SIGTERM:
# Returns 0 if pgcat process exists
time.sleep(2)
if not os.system('pgrep pgcat'):
raise Exception("pgcat not closed after SIGTERM")
def connect_db(
autocommit: bool = True,
admin: bool = False,
) -> Tuple[psycopg2.extensions.connection, psycopg2.extensions.cursor]:
if admin:
user = "admin_user"
password = "admin_pass"
db = "pgcat"
else:
user = "sharding_user"
password = "sharding_user"
db = "sharded_db"
conn = psycopg2.connect(
f"postgres://{user}:{password}@{PGCAT_HOST}:{PGCAT_PORT}/{db}?application_name=testing_pgcat",
connect_timeout=2,
)
conn.autocommit = autocommit
cur = conn.cursor()
return (conn, cur)
def cleanup_conn(conn: psycopg2.extensions.connection, cur: psycopg2.extensions.cursor):
cur.close()
conn.close()
def test_normal_db_access():
pgcat_start()
conn, cur = connect_db(autocommit=False)
cur.execute("SELECT 1")
res = cur.fetchall()
print(res)
cleanup_conn(conn, cur)
def test_admin_db_access():
conn, cur = connect_db(admin=True)
cur.execute("SHOW POOLS")
res = cur.fetchall()
print(res)
cleanup_conn(conn, cur)
def test_shutdown_logic():
@@ -14,17 +86,17 @@ def test_shutdown_logic():
# NO ACTIVE QUERIES SIGINT HANDLING
# Start pgcat
utils.pgcat_start()
pgcat_start()
# Create client connection and send query (not in transaction)
conn, cur = utils.connect_db()
conn, cur = connect_db()
cur.execute("BEGIN;")
cur.execute("SELECT 1;")
cur.execute("COMMIT;")
# Send sigint to pgcat
utils.pg_cat_send_signal(signal.SIGINT)
pg_cat_send_signal(signal.SIGINT)
time.sleep(1)
# Check that any new queries fail after sigint since server should close with no active transactions
@@ -36,18 +108,18 @@ def test_shutdown_logic():
# Fail if query execution succeeded
raise Exception("Server not closed after sigint")
utils.cleanup_conn(conn, cur)
utils.pg_cat_send_signal(signal.SIGTERM)
cleanup_conn(conn, cur)
pg_cat_send_signal(signal.SIGTERM)
# - - - - - - - - - - - - - - - - - -
# NO ACTIVE QUERIES ADMIN SHUTDOWN COMMAND
# Start pgcat
utils.pgcat_start()
pgcat_start()
# Create client connection and begin transaction
conn, cur = utils.connect_db()
admin_conn, admin_cur = utils.connect_db(admin=True)
conn, cur = connect_db()
admin_conn, admin_cur = connect_db(admin=True)
cur.execute("BEGIN;")
cur.execute("SELECT 1;")
@@ -66,24 +138,24 @@ def test_shutdown_logic():
# Fail if query execution succeeded
raise Exception("Server not closed after sigint")
utils.cleanup_conn(conn, cur)
utils.cleanup_conn(admin_conn, admin_cur)
utils.pg_cat_send_signal(signal.SIGTERM)
cleanup_conn(conn, cur)
cleanup_conn(admin_conn, admin_cur)
pg_cat_send_signal(signal.SIGTERM)
# - - - - - - - - - - - - - - - - - -
# HANDLE TRANSACTION WITH SIGINT
# Start pgcat
utils.pgcat_start()
pgcat_start()
# Create client connection and begin transaction
conn, cur = utils.connect_db()
conn, cur = connect_db()
cur.execute("BEGIN;")
cur.execute("SELECT 1;")
# Send sigint to pgcat while still in transaction
utils.pg_cat_send_signal(signal.SIGINT)
pg_cat_send_signal(signal.SIGINT)
time.sleep(1)
# Check that any new queries succeed after sigint since server should still allow transaction to complete
@@ -93,18 +165,18 @@ def test_shutdown_logic():
# Fail if query fails since server closed
raise Exception("Server closed while in transaction", e.pgerror)
utils.cleanup_conn(conn, cur)
utils.pg_cat_send_signal(signal.SIGTERM)
cleanup_conn(conn, cur)
pg_cat_send_signal(signal.SIGTERM)
# - - - - - - - - - - - - - - - - - -
# HANDLE TRANSACTION WITH ADMIN SHUTDOWN COMMAND
# Start pgcat
utils.pgcat_start()
pgcat_start()
# Create client connection and begin transaction
conn, cur = utils.connect_db()
admin_conn, admin_cur = utils.connect_db(admin=True)
conn, cur = connect_db()
admin_conn, admin_cur = connect_db(admin=True)
cur.execute("BEGIN;")
cur.execute("SELECT 1;")
@@ -122,30 +194,30 @@ def test_shutdown_logic():
# Fail if query fails since server closed
raise Exception("Server closed while in transaction", e.pgerror)
utils.cleanup_conn(conn, cur)
utils.cleanup_conn(admin_conn, admin_cur)
utils.pg_cat_send_signal(signal.SIGTERM)
cleanup_conn(conn, cur)
cleanup_conn(admin_conn, admin_cur)
pg_cat_send_signal(signal.SIGTERM)
# - - - - - - - - - - - - - - - - - -
# NO NEW NON-ADMIN CONNECTIONS DURING SHUTDOWN
# Start pgcat
utils.pgcat_start()
pgcat_start()
# Create client connection and begin transaction
transaction_conn, transaction_cur = utils.connect_db()
transaction_conn, transaction_cur = connect_db()
transaction_cur.execute("BEGIN;")
transaction_cur.execute("SELECT 1;")
# Send sigint to pgcat while still in transaction
utils.pg_cat_send_signal(signal.SIGINT)
pg_cat_send_signal(signal.SIGINT)
time.sleep(1)
start = time.perf_counter()
try:
conn, cur = utils.connect_db()
conn, cur = connect_db()
cur.execute("SELECT 1;")
utils.cleanup_conn(conn, cur)
cleanup_conn(conn, cur)
except psycopg2.OperationalError as e:
time_taken = time.perf_counter() - start
if time_taken > 0.1:
@@ -155,49 +227,49 @@ def test_shutdown_logic():
else:
raise Exception("Able connect to database during shutdown")
utils.cleanup_conn(transaction_conn, transaction_cur)
utils.pg_cat_send_signal(signal.SIGTERM)
cleanup_conn(transaction_conn, transaction_cur)
pg_cat_send_signal(signal.SIGTERM)
# - - - - - - - - - - - - - - - - - -
# ALLOW NEW ADMIN CONNECTIONS DURING SHUTDOWN
# Start pgcat
utils.pgcat_start()
pgcat_start()
# Create client connection and begin transaction
transaction_conn, transaction_cur = utils.connect_db()
transaction_conn, transaction_cur = connect_db()
transaction_cur.execute("BEGIN;")
transaction_cur.execute("SELECT 1;")
# Send sigint to pgcat while still in transaction
utils.pg_cat_send_signal(signal.SIGINT)
pg_cat_send_signal(signal.SIGINT)
time.sleep(1)
try:
conn, cur = utils.connect_db(admin=True)
conn, cur = connect_db(admin=True)
cur.execute("SHOW DATABASES;")
utils.cleanup_conn(conn, cur)
cleanup_conn(conn, cur)
except psycopg2.OperationalError as e:
raise Exception(e)
utils.cleanup_conn(transaction_conn, transaction_cur)
utils.pg_cat_send_signal(signal.SIGTERM)
cleanup_conn(transaction_conn, transaction_cur)
pg_cat_send_signal(signal.SIGTERM)
# - - - - - - - - - - - - - - - - - -
# ADMIN CONNECTIONS CONTINUING TO WORK AFTER SHUTDOWN
# Start pgcat
utils.pgcat_start()
pgcat_start()
# Create client connection and begin transaction
transaction_conn, transaction_cur = utils.connect_db()
transaction_conn, transaction_cur = connect_db()
transaction_cur.execute("BEGIN;")
transaction_cur.execute("SELECT 1;")
admin_conn, admin_cur = utils.connect_db(admin=True)
admin_conn, admin_cur = connect_db(admin=True)
admin_cur.execute("SHOW DATABASES;")
# Send sigint to pgcat while still in transaction
utils.pg_cat_send_signal(signal.SIGINT)
pg_cat_send_signal(signal.SIGINT)
time.sleep(1)
try:
@@ -205,24 +277,24 @@ def test_shutdown_logic():
except psycopg2.OperationalError as e:
raise Exception("Could not execute admin command:", e)
utils.cleanup_conn(transaction_conn, transaction_cur)
utils.cleanup_conn(admin_conn, admin_cur)
utils.pg_cat_send_signal(signal.SIGTERM)
cleanup_conn(transaction_conn, transaction_cur)
cleanup_conn(admin_conn, admin_cur)
pg_cat_send_signal(signal.SIGTERM)
# - - - - - - - - - - - - - - - - - -
# HANDLE SHUTDOWN TIMEOUT WITH SIGINT
# Start pgcat
utils.pgcat_start()
pgcat_start()
# Create client connection and begin transaction, which should prevent server shutdown unless shutdown timeout is reached
conn, cur = utils.connect_db()
conn, cur = connect_db()
cur.execute("BEGIN;")
cur.execute("SELECT 1;")
# Send sigint to pgcat while still in transaction
utils.pg_cat_send_signal(signal.SIGINT)
pg_cat_send_signal(signal.SIGINT)
# pgcat shutdown timeout is set to SHUTDOWN_TIMEOUT seconds, so we sleep for SHUTDOWN_TIMEOUT + 1 seconds
time.sleep(SHUTDOWN_TIMEOUT + 1)
@@ -236,7 +308,12 @@ def test_shutdown_logic():
# Fail if query execution succeeded
raise Exception("Server not closed after sigint and expected timeout")
utils.cleanup_conn(conn, cur)
utils.pg_cat_send_signal(signal.SIGTERM)
cleanup_conn(conn, cur)
pg_cat_send_signal(signal.SIGTERM)
# - - - - - - - - - - - - - - - - - -
test_normal_db_access()
test_admin_db_access()
test_shutdown_logic()

View File

@@ -1,110 +0,0 @@
import os
import signal
import time
from typing import Tuple
import tempfile
import psutil
import psycopg2
PGCAT_HOST = "127.0.0.1"
PGCAT_PORT = "6432"
def _pgcat_start(config_path: str):
pg_cat_send_signal(signal.SIGTERM)
os.system(f"./target/debug/pgcat {config_path} &")
time.sleep(2)
def pgcat_start():
_pgcat_start(config_path='.circleci/pgcat.toml')
def pgcat_generic_start(config: str):
tmp = tempfile.NamedTemporaryFile()
with open(tmp.name, 'w') as f:
f.write(config)
_pgcat_start(config_path=tmp.name)
def glauth_send_signal(signal: signal.Signals):
try:
for proc in psutil.process_iter(["pid", "name"]):
if proc.name() == "glauth":
os.kill(proc.pid, signal)
except Exception as e:
# The process can be gone when we send this signal
print(e)
if signal == signal.SIGTERM:
# Returns 0 if pgcat process exists
time.sleep(2)
if not os.system('pgrep glauth'):
raise Exception("glauth not closed after SIGTERM")
def pg_cat_send_signal(signal: signal.Signals):
try:
for proc in psutil.process_iter(["pid", "name"]):
if "pgcat" == proc.name():
os.kill(proc.pid, signal)
except Exception as e:
# The process can be gone when we send this signal
print(e)
if signal == signal.SIGTERM:
# Returns 0 if pgcat process exists
time.sleep(2)
if not os.system('pgrep pgcat'):
raise Exception("pgcat not closed after SIGTERM")
def connect_db(
autocommit: bool = True,
admin: bool = False,
) -> Tuple[psycopg2.extensions.connection, psycopg2.extensions.cursor]:
if admin:
user = "admin_user"
password = "admin_pass"
db = "pgcat"
else:
user = "sharding_user"
password = "sharding_user"
db = "sharded_db"
conn = psycopg2.connect(
f"postgres://{user}:{password}@{PGCAT_HOST}:{PGCAT_PORT}/{db}?application_name=testing_pgcat",
connect_timeout=2,
)
conn.autocommit = autocommit
cur = conn.cursor()
return (conn, cur)
def connect_db_trust(
autocommit: bool = True,
admin: bool = False,
) -> Tuple[psycopg2.extensions.connection, psycopg2.extensions.cursor]:
if admin:
user = "admin_user"
db = "pgcat"
else:
user = "sharding_user"
db = "sharded_db"
conn = psycopg2.connect(
f"postgres://{user}@{PGCAT_HOST}:{PGCAT_PORT}/{db}?application_name=testing_pgcat",
connect_timeout=2,
)
conn.autocommit = autocommit
cur = conn.cursor()
return (conn, cur)
def cleanup_conn(conn: psycopg2.extensions.connection, cur: psycopg2.extensions.cursor):
cur.close()
conn.close()

View File

@@ -1,33 +1,22 @@
GEM
remote: https://rubygems.org/
specs:
activemodel (7.1.4)
activesupport (= 7.1.4)
activerecord (7.1.4)
activemodel (= 7.1.4)
activesupport (= 7.1.4)
timeout (>= 0.4.0)
activesupport (7.1.4)
base64
bigdecimal
activemodel (7.0.4.1)
activesupport (= 7.0.4.1)
activerecord (7.0.4.1)
activemodel (= 7.0.4.1)
activesupport (= 7.0.4.1)
activesupport (7.0.4.1)
concurrent-ruby (~> 1.0, >= 1.0.2)
connection_pool (>= 2.2.5)
drb
i18n (>= 1.6, < 2)
minitest (>= 5.1)
mutex_m
tzinfo (~> 2.0)
ast (2.4.2)
base64 (0.2.0)
bigdecimal (3.1.8)
concurrent-ruby (1.3.4)
connection_pool (2.4.1)
concurrent-ruby (1.1.10)
diff-lcs (1.5.0)
drb (2.2.1)
i18n (1.14.5)
i18n (1.12.0)
concurrent-ruby (~> 1.0)
minitest (5.25.1)
mutex_m (0.2.0)
minitest (5.17.0)
parallel (1.22.1)
parser (3.1.2.0)
ast (~> 2.4.1)
@@ -35,8 +24,7 @@ GEM
pg (1.3.2)
rainbow (3.1.1)
regexp_parser (2.3.1)
rexml (3.3.6)
strscan
rexml (3.2.5)
rspec (3.11.0)
rspec-core (~> 3.11.0)
rspec-expectations (~> 3.11.0)
@@ -62,12 +50,10 @@ GEM
rubocop-ast (1.17.0)
parser (>= 3.1.1.0)
ruby-progressbar (1.11.0)
strscan (3.1.0)
timeout (0.4.1)
toml (0.3.0)
parslet (>= 1.8.0, < 3.0.0)
toxiproxy (2.0.1)
tzinfo (2.0.6)
tzinfo (2.0.5)
concurrent-ruby (~> 1.0)
unicode-display_width (2.1.0)

View File

@@ -0,0 +1,145 @@
class PostgresMessage
# Base class for common functionality
def encode_string(str)
"#{str}\0" # Encode a string with a null terminator
end
def encode_int16(value)
[value].pack('n') # Encode an Int16
end
def encode_int32(value)
[value].pack('N') # Encode an Int32
end
def message_prefix(type, length)
"#{type}#{encode_int32(length)}" # Message type and length prefix
end
end
class SimpleQueryMessage < PostgresMessage
attr_accessor :query
def initialize(query = "")
@query = query
end
def to_bytes
query_bytes = encode_string(@query)
length = 4 + query_bytes.size # Length includes 4 bytes for length itself
message_prefix('Q', length) + query_bytes
end
end
class ParseMessage < PostgresMessage
attr_accessor :statement_name, :query, :parameter_types
def initialize(statement_name = "", query = "", parameter_types = [])
@statement_name = statement_name
@query = query
@parameter_types = parameter_types
end
def to_bytes
statement_name_bytes = encode_string(@statement_name)
query_bytes = encode_string(@query)
parameter_types_bytes = @parameter_types.pack('N*')
length = 4 + statement_name_bytes.size + query_bytes.size + 2 + parameter_types_bytes.size
message_prefix('P', length) + statement_name_bytes + query_bytes + encode_int16(@parameter_types.size) + parameter_types_bytes
end
end
class BindMessage < PostgresMessage
attr_accessor :portal_name, :statement_name, :parameter_format_codes, :parameters, :result_column_format_codes
def initialize(portal_name = "", statement_name = "", parameter_format_codes = [], parameters = [], result_column_format_codes = [])
@portal_name = portal_name
@statement_name = statement_name
@parameter_format_codes = parameter_format_codes
@parameters = parameters
@result_column_format_codes = result_column_format_codes
end
def to_bytes
portal_name_bytes = encode_string(@portal_name)
statement_name_bytes = encode_string(@statement_name)
parameter_format_codes_bytes = @parameter_format_codes.pack('n*')
parameters_bytes = @parameters.map do |param|
if param.nil?
encode_int32(-1)
else
encode_int32(param.bytesize) + param
end
end.join
result_column_format_codes_bytes = @result_column_format_codes.pack('n*')
length = 4 + portal_name_bytes.size + statement_name_bytes.size + 2 + parameter_format_codes_bytes.size + 2 + parameters_bytes.size + 2 + result_column_format_codes_bytes.size
message_prefix('B', length) + portal_name_bytes + statement_name_bytes + encode_int16(@parameter_format_codes.size) + parameter_format_codes_bytes + encode_int16(@parameters.size) + parameters_bytes + encode_int16(@result_column_format_codes.size) + result_column_format_codes_bytes
end
end
class DescribeMessage < PostgresMessage
attr_accessor :type, :name
def initialize(type = 'S', name = "")
@type = type
@name = name
end
def to_bytes
name_bytes = encode_string(@name)
length = 4 + 1 + name_bytes.size
message_prefix('D', length) + @type + name_bytes
end
end
class ExecuteMessage < PostgresMessage
attr_accessor :portal_name, :max_rows
def initialize(portal_name = "", max_rows = 0)
@portal_name = portal_name
@max_rows = max_rows
end
def to_bytes
portal_name_bytes = encode_string(@portal_name)
length = 4 + portal_name_bytes.size + 4
message_prefix('E', length) + portal_name_bytes + encode_int32(@max_rows)
end
end
class FlushMessage < PostgresMessage
def to_bytes
length = 4
message_prefix('H', length)
end
end
class SyncMessage < PostgresMessage
def to_bytes
length = 4
message_prefix('S', length)
end
end
class CloseMessage < PostgresMessage
attr_accessor :type, :name
def initialize(type = 'S', name = "")
@type = type
@name = name
end
def to_bytes
name_bytes = encode_string(@name)
length = 4 + 1 + name_bytes.size
message_prefix('C', length) + @type + name_bytes
end
end

View File

@@ -1,5 +1,6 @@
require 'socket'
require 'digest/md5'
require_relative 'frontend_messages'
BACKEND_MESSAGE_CODES = {
'Z' => "ReadyForQuery",
@@ -18,9 +19,13 @@ class PostgresSocket
@host = host
@socket = TCPSocket.new @host, @port
@parameters = {}
@verbose = true
@verbose = false
end
def send_message(message)
@socket.write(message.to_bytes)
end
def send_md5_password_message(username, password, salt)
m = Digest::MD5.hexdigest(password + username)
m = Digest::MD5.hexdigest(m + salt.map(&:chr).join(""))
@@ -113,107 +118,6 @@ class PostgresSocket
log "[F] Sent CancelRequest message"
end
def send_query_message(query)
query_size = query.length
message_size = 1 + 4 + query_size
message = []
message << "Q".ord
message << [message_size].pack('l>').unpack('CCCC') # 4
message << query.split('').map(&:ord) # 2, 11
message << 0 # 1, 12
message.flatten!
@socket.write(message.flatten.pack('C*'))
log "[F] Sent Q message (#{query})"
end
def send_parse_message(query)
query_size = query.length
message_size = 2 + 2 + 4 + query_size
message = []
message << "P".ord
message << [message_size].pack('l>').unpack('CCCC') # 4
message << 0 # unnamed statement
message << query.split('').map(&:ord) # 2, 11
message << 0 # 1, 12
message << [0, 0]
message.flatten!
@socket.write(message.flatten.pack('C*'))
log "[F] Sent P message (#{query})"
end
def send_bind_message
message = []
message << "B".ord
message << [12].pack('l>').unpack('CCCC') # 4
message << 0 # unnamed statement
message << 0 # unnamed statement
message << [0, 0] # 2
message << [0, 0] # 2
message << [0, 0] # 2
message.flatten!
@socket.write(message.flatten.pack('C*'))
log "[F] Sent B message"
end
def send_describe_message(mode)
message = []
message << "D".ord
message << [6].pack('l>').unpack('CCCC') # 4
message << mode.ord
message << 0 # unnamed statement
message.flatten!
@socket.write(message.flatten.pack('C*'))
log "[F] Sent D message"
end
def send_execute_message(limit=0)
message = []
message << "E".ord
message << [9].pack('l>').unpack('CCCC') # 4
message << 0 # unnamed statement
message << [limit].pack('l>').unpack('CCCC') # 4
message.flatten!
@socket.write(message.flatten.pack('C*'))
log "[F] Sent E message"
end
def send_sync_message
message = []
message << "S".ord
message << [4].pack('l>').unpack('CCCC') # 4
message.flatten!
@socket.write(message.flatten.pack('C*'))
log "[F] Sent S message"
end
def send_copydone_message
message = []
message << "c".ord
message << [4].pack('l>').unpack('CCCC') # 4
message.flatten!
@socket.write(message.flatten.pack('C*'))
log "[F] Sent c message"
end
def send_copyfail_message
message = []
message << "f".ord
message << [5].pack('l>').unpack('CCCC') # 4
message << 0
message.flatten!
@socket.write(message.flatten.pack('C*'))
log "[F] Sent f message"
end
def send_flush_message
message = []
message << "H".ord
message << [4].pack('l>').unpack('CCCC') # 4
message.flatten!
@socket.write(message.flatten.pack('C*'))
log "[F] Sent H message"
end
def read_from_server()
output_messages = []
retry_count = 0

View File

@@ -16,10 +16,14 @@ describe "Portocol handling" do
end
def run_comparison(sequence, socket_a, socket_b)
sequence.each do |msg, *args|
socket_a.send(msg, *args)
socket_b.send(msg, *args)
sequence.each do |msg|
if msg.is_a?(Symbol)
socket_a.send(msg)
socket_b.send(msg)
else
socket_a.send_message(msg)
socket_b.send_message(msg)
end
compare_messages(
socket_a.read_from_server,
socket_b.read_from_server
@@ -83,9 +87,9 @@ describe "Portocol handling" do
context "Cancel Query" do
let(:sequence) {
[
[:send_query_message, "SELECT pg_sleep(5)"],
[:cancel_query]
[
SimpleQueryMessage.new("SELECT pg_sleep(5)"),
:cancel_query
]
}
@@ -95,12 +99,12 @@ describe "Portocol handling" do
xcontext "Simple query after parse" do
let(:sequence) {
[
[:send_parse_message, "SELECT 5"],
[:send_query_message, "SELECT 1"],
[:send_bind_message],
[:send_describe_message, "P"],
[:send_execute_message],
[:send_sync_message],
ParseMessage.new("", "SELECT 5", []),
SimpleQueryMessage.new("SELECT 1"),
BindMessage.new("", "", [], [], [0]),
DescribeMessage.new("P", ""),
ExecuteMessage.new("", 1),
SyncMessage.new
]
}
@@ -111,8 +115,8 @@ describe "Portocol handling" do
xcontext "Flush message" do
let(:sequence) {
[
[:send_parse_message, "SELECT 1"],
[:send_flush_message]
ParseMessage.new("", "SELECT 1", []),
FlushMessage.new
]
}
@@ -122,9 +126,7 @@ describe "Portocol handling" do
xcontext "Bind without parse" do
let(:sequence) {
[
[:send_bind_message]
]
[BindMessage.new("", "", [], [], [0])]
}
# This is known to fail.
# Server responds immediately, Proxy buffers the message
@@ -133,23 +135,155 @@ describe "Portocol handling" do
context "Simple message" do
let(:sequence) {
[[:send_query_message, "SELECT 1"]]
[SimpleQueryMessage.new("SELECT 1")]
}
it_behaves_like "at parity with database"
end
10.times do |i|
context "Extended protocol" do
let(:sequence) {
[
ParseMessage.new("", "SELECT 1", []),
BindMessage.new("", "", [], [], [0]),
DescribeMessage.new("S", ""),
ExecuteMessage.new("", 1),
SyncMessage.new
]
}
context "Extended protocol" do
let(:sequence) {
[
[:send_parse_message, "SELECT 1"],
[:send_bind_message],
[:send_describe_message, "P"],
[:send_execute_message],
[:send_sync_message],
]
}
it_behaves_like "at parity with database"
it_behaves_like "at parity with database"
end
end
end
describe "Protocol-level prepared statements" do
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 1, "transaction") }
before do
q_sock = PostgresSocket.new('localhost', processes.pgcat.port)
q_sock.send_startup_message("sharding_user", "sharded_db", "sharding_user")
table_query = "CREATE TABLE IF NOT EXISTS employees (employee_id SERIAL PRIMARY KEY, salary NUMERIC(10, 2) CHECK (salary > 0));"
q_sock.send_message(SimpleQueryMessage.new(table_query))
q_sock.close
current_configs = processes.pgcat.current_config
current_configs["pools"]["sharded_db"]["prepared_statements_cache_size"] = 500
processes.pgcat.update_config(current_configs)
processes.pgcat.reload_config
end
after do
q_sock = PostgresSocket.new('localhost', processes.pgcat.port)
q_sock.send_startup_message("sharding_user", "sharded_db", "sharding_user")
table_query = "DROP TABLE IF EXISTS employees;"
q_sock.send_message(SimpleQueryMessage.new(table_query))
q_sock.close
end
context "When unnamed prepared statements are used" do
it "does not cache them" do
socket = PostgresSocket.new('localhost', processes.pgcat.port)
socket.send_startup_message("sharding_user", "sharded_db", "sharding_user")
socket.send_message(SimpleQueryMessage.new("DISCARD ALL"))
socket.read_from_server
10.times do |i|
socket.send_message(ParseMessage.new("", "SELECT #{i}", []))
socket.send_message(BindMessage.new("", "", [], [], [0]))
socket.send_message(DescribeMessage.new("S", ""))
socket.send_message(ExecuteMessage.new("", 1))
socket.send_message(SyncMessage.new)
socket.read_from_server
end
socket.send_message(SimpleQueryMessage.new("SELECT name, statement, prepare_time, parameter_types FROM pg_prepared_statements"))
result = socket.read_from_server
number_of_saved_statements = result.count { |m| m[:code] == 'D' }
expect(number_of_saved_statements).to eq(0)
end
end
context "When named prepared statements are used" do
it "caches them" do
socket = PostgresSocket.new('localhost', processes.pgcat.port)
socket.send_startup_message("sharding_user", "sharded_db", "sharding_user")
socket.send_message(SimpleQueryMessage.new("DISCARD ALL"))
socket.read_from_server
3.times do
socket.send_message(ParseMessage.new("my_query", "SELECT * FROM employees WHERE employee_id in ($1,$2,$3)", [0,0,0]))
socket.send_message(BindMessage.new("", "my_query", [0,0,0], [0,0,0].map(&:to_s), [0,0,0,0,0,0]))
socket.send_message(SyncMessage.new)
socket.read_from_server
end
3.times do
socket.send_message(ParseMessage.new("my_other_query", "SELECT * FROM employees WHERE salary in ($1,$2,$3)", [0,0,0]))
socket.send_message(BindMessage.new("", "my_other_query", [0,0,0], [0,0,0].map(&:to_s), [0,0,0,0,0,0]))
socket.send_message(SyncMessage.new)
socket.read_from_server
end
socket.send_message(SimpleQueryMessage.new("SELECT name, statement, prepare_time, parameter_types FROM pg_prepared_statements"))
result = socket.read_from_server
number_of_saved_statements = result.count { |m| m[:code] == 'D' }
expect(number_of_saved_statements).to eq(2)
end
end
context "When DISCARD ALL/DEALLOCATE ALL are called" do
it "resets server and client caches" do
socket = PostgresSocket.new('localhost', processes.pgcat.port)
socket.send_startup_message("sharding_user", "sharded_db", "sharding_user")
20.times do |i|
socket.send_message(ParseMessage.new("my_query_#{i}", "SELECT * FROM employees WHERE employee_id in ($1,$2,$3)", [0,0,0]))
end
20.times do |i|
socket.send_message(BindMessage.new("", "my_query_#{i}", [0,0,0], [0,0,0].map(&:to_s), [0,0]))
end
socket.send_message(SyncMessage.new)
socket.read_from_server
socket.send_message(SimpleQueryMessage.new("DISCARD ALL"))
socket.read_from_server
responses = []
4.times do |i|
socket.send_message(ParseMessage.new("my_query_#{i}", "SELECT * FROM employees WHERE employee_id in ($1,$2,$3)", [0,0,0]))
socket.send_message(BindMessage.new("", "my_query_#{i}", [0,0,0], [0,0,0].map(&:to_s), [0,0]))
socket.send_message(SyncMessage.new)
responses += socket.read_from_server
end
errors = responses.select { |message| message[:code] == 'E' }
error_message = errors.map { |message| message[:bytes].map(&:chr).join("") }.join("\n")
raise StandardError, "Encountered the following errors: #{error_message}" if errors.length > 0
end
end
context "Maximum number of bound paramters" do
it "does not crash" do
test_socket = PostgresSocket.new('localhost', processes.pgcat.port)
test_socket.send_startup_message("sharding_user", "sharded_db", "sharding_user")
types = Array.new(65_535) { |i| 0 }
params = Array.new(65_535) { |i| "$#{i+1}" }.join(",")
test_socket.send_message(ParseMessage.new("my_query", "SELECT * FROM employees WHERE employee_id in (#{params})", types))
test_socket.send_message(BindMessage.new("my_query", "my_query", types, types.map(&:to_s), types))
test_socket.send_message(SyncMessage.new)
# If the proxy crashes, this will raise an error
expect { test_socket.read_from_server }.to_not raise_error
test_socket.close
end
end
end
end