mirror of
https://github.com/postgresml/pgcat.git
synced 2026-03-23 01:16:30 +00:00
Compare commits
8 Commits
mostafa_fi
...
circleci_A
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fba40eba2f | ||
|
|
d8ccf4babb | ||
|
|
feedcd49d9 | ||
|
|
9bb71ede9d | ||
|
|
88b2afb19b | ||
|
|
f0865ca616 | ||
|
|
7d047c6c19 | ||
|
|
f73d15f82c |
@@ -26,6 +26,7 @@ PGPASSWORD=sharding_user pgbench -h 127.0.0.1 -U sharding_user shard1 -i
|
||||
PGPASSWORD=sharding_user pgbench -h 127.0.0.1 -U sharding_user shard2 -i
|
||||
|
||||
# Start Toxiproxy
|
||||
kill -9 $(pgrep toxiproxy) || true
|
||||
LOG_LEVEL=error toxiproxy-server &
|
||||
sleep 1
|
||||
|
||||
@@ -106,7 +107,7 @@ cd ../..
|
||||
# These tests will start and stop the pgcat server so it will need to be restarted after the tests
|
||||
#
|
||||
pip3 install -r tests/python/requirements.txt
|
||||
python3 tests/python/tests.py || exit 1
|
||||
pytest || exit 1
|
||||
|
||||
|
||||
#
|
||||
@@ -177,3 +178,6 @@ killall pgcat -s SIGINT
|
||||
|
||||
# Allow for graceful shutdown
|
||||
sleep 1
|
||||
|
||||
kill -9 $(pgrep toxiproxy)
|
||||
sleep 1
|
||||
|
||||
15
.github/workflows/publish-deb-package.yml
vendored
15
.github/workflows/publish-deb-package.yml
vendored
@@ -1,6 +1,9 @@
|
||||
name: pgcat package (deb)
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- v*
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
packageVersion:
|
||||
@@ -16,6 +19,14 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set package version
|
||||
if: github.event_name == 'push' # For push event
|
||||
run: |
|
||||
TAG=${{ github.ref_name }}
|
||||
echo "packageVersion=${TAG#v}" >> "$GITHUB_ENV"
|
||||
- name: Set package version (manual dispatch)
|
||||
if: github.event_name == 'workflow_dispatch' # For manual dispatch
|
||||
run: echo "packageVersion=${{ github.event.inputs.packageVersion }}" >> "$GITHUB_ENV"
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: stable
|
||||
@@ -39,10 +50,10 @@ jobs:
|
||||
export ARCH=arm64
|
||||
fi
|
||||
|
||||
bash utilities/deb.sh ${{ inputs.packageVersion }}
|
||||
bash utilities/deb.sh ${{ env.packageVersion }}
|
||||
|
||||
deb-s3 upload \
|
||||
--lock \
|
||||
--bucket apt.postgresml.org \
|
||||
pgcat-${{ inputs.packageVersion }}-ubuntu22.04-${ARCH}.deb \
|
||||
pgcat-${{ env.packageVersion }}-ubuntu22.04-${ARCH}.deb \
|
||||
--codename $(lsb_release -cs)
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -10,4 +10,5 @@ lcov.info
|
||||
dev/.bash_history
|
||||
dev/cache
|
||||
!dev/cache/.keepme
|
||||
.venv
|
||||
.venv
|
||||
**/__pycache__
|
||||
|
||||
@@ -7,7 +7,7 @@ Thank you for contributing! Just a few tips here:
|
||||
3. Performance is important, make sure there are no regressions in your branch vs. `main`.
|
||||
|
||||
## How to run the integration tests locally and iterate on them
|
||||
We have integration tests written in Ruby, Python, Go and Rust.
|
||||
We have integration tests written in Ruby, Python, Go and Rust.
|
||||
Below are the steps to run them in a developer-friendly way that allows iterating and quick turnaround.
|
||||
Hear me out, this should be easy, it will involve opening a shell into a container with all the necessary dependancies available for you and you can modify the test code and immediately rerun your test in the interactive shell.
|
||||
|
||||
@@ -21,7 +21,7 @@ Within this test environment you can modify the file in your favorite IDE and re
|
||||
|
||||
Once the environment is ready, you can run the tests by running
|
||||
Ruby: `cd /app/tests/ruby && bundle exec ruby <test_name>.rb --format documentation`
|
||||
Python: `cd /app && python3 tests/python/tests.py`
|
||||
Python: `cd /app/ && pytest`
|
||||
Rust: `cd /app/tests/rust && cargo run`
|
||||
Go: `cd /app/tests/go && /usr/local/go/bin/go test`
|
||||
|
||||
|
||||
4
postinst
4
postinst
@@ -7,3 +7,7 @@ systemctl enable pgcat
|
||||
if ! id pgcat 2> /dev/null; then
|
||||
useradd -s /usr/bin/false pgcat
|
||||
fi
|
||||
|
||||
if [ -f /etc/pgcat.toml ]; then
|
||||
systemctl start pgcat
|
||||
fi
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use crate::config::AuthType;
|
||||
use crate::errors::Error;
|
||||
use crate::pool::ConnectionPool;
|
||||
use crate::server::Server;
|
||||
@@ -71,6 +72,7 @@ impl AuthPassthrough {
|
||||
pub async fn fetch_hash(&self, address: &crate::config::Address) -> Result<String, Error> {
|
||||
let auth_user = crate::config::User {
|
||||
username: self.user.clone(),
|
||||
auth_type: AuthType::MD5,
|
||||
password: Some(self.password.clone()),
|
||||
server_username: None,
|
||||
server_password: None,
|
||||
|
||||
332
src/client.rs
332
src/client.rs
@@ -14,7 +14,9 @@ use tokio::sync::mpsc::Sender;
|
||||
|
||||
use crate::admin::{generate_server_parameters_for_admin, handle_admin};
|
||||
use crate::auth_passthrough::refetch_auth_hash;
|
||||
use crate::config::{get_config, get_idle_client_in_transaction_timeout, Address, PoolMode};
|
||||
use crate::config::{
|
||||
get_config, get_idle_client_in_transaction_timeout, Address, AuthType, PoolMode,
|
||||
};
|
||||
use crate::constants::*;
|
||||
use crate::messages::*;
|
||||
use crate::plugins::PluginOutput;
|
||||
@@ -463,8 +465,8 @@ where
|
||||
.count()
|
||||
== 1;
|
||||
|
||||
// Kick any client that's not admin while we're in admin-only mode.
|
||||
if !admin && admin_only {
|
||||
// Kick any client that's not admin while we're in admin-only mode.
|
||||
debug!(
|
||||
"Rejecting non-admin connection to {} when in admin only mode",
|
||||
pool_name
|
||||
@@ -481,72 +483,76 @@ where
|
||||
let process_id: i32 = rand::random();
|
||||
let secret_key: i32 = rand::random();
|
||||
|
||||
// Perform MD5 authentication.
|
||||
// TODO: Add SASL support.
|
||||
let salt = md5_challenge(&mut write).await?;
|
||||
|
||||
let code = match read.read_u8().await {
|
||||
Ok(p) => p,
|
||||
Err(_) => {
|
||||
return Err(Error::ClientSocketError(
|
||||
"password code".into(),
|
||||
client_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
// PasswordMessage
|
||||
if code as char != 'p' {
|
||||
return Err(Error::ProtocolSyncError(format!(
|
||||
"Expected p, got {}",
|
||||
code as char
|
||||
)));
|
||||
}
|
||||
|
||||
let len = match read.read_i32().await {
|
||||
Ok(len) => len,
|
||||
Err(_) => {
|
||||
return Err(Error::ClientSocketError(
|
||||
"password message length".into(),
|
||||
client_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let mut password_response = vec![0u8; (len - 4) as usize];
|
||||
|
||||
match read.read_exact(&mut password_response).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => {
|
||||
return Err(Error::ClientSocketError(
|
||||
"password message".into(),
|
||||
client_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let mut prepared_statements_enabled = false;
|
||||
|
||||
// Authenticate admin user.
|
||||
let (transaction_mode, mut server_parameters) = if admin {
|
||||
let config = get_config();
|
||||
// TODO: Add SASL support.
|
||||
// Perform MD5 authentication.
|
||||
match config.general.admin_auth_type {
|
||||
AuthType::Trust => (),
|
||||
AuthType::MD5 => {
|
||||
let salt = md5_challenge(&mut write).await?;
|
||||
|
||||
// Compare server and client hashes.
|
||||
let password_hash = md5_hash_password(
|
||||
&config.general.admin_username,
|
||||
&config.general.admin_password,
|
||||
&salt,
|
||||
);
|
||||
let code = match read.read_u8().await {
|
||||
Ok(p) => p,
|
||||
Err(_) => {
|
||||
return Err(Error::ClientSocketError(
|
||||
"password code".into(),
|
||||
client_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
if password_hash != password_response {
|
||||
let error = Error::ClientGeneralError("Invalid password".into(), client_identifier);
|
||||
// PasswordMessage
|
||||
if code as char != 'p' {
|
||||
return Err(Error::ProtocolSyncError(format!(
|
||||
"Expected p, got {}",
|
||||
code as char
|
||||
)));
|
||||
}
|
||||
|
||||
warn!("{}", error);
|
||||
wrong_password(&mut write, username).await?;
|
||||
let len = match read.read_i32().await {
|
||||
Ok(len) => len,
|
||||
Err(_) => {
|
||||
return Err(Error::ClientSocketError(
|
||||
"password message length".into(),
|
||||
client_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
return Err(error);
|
||||
let mut password_response = vec![0u8; (len - 4) as usize];
|
||||
|
||||
match read.read_exact(&mut password_response).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => {
|
||||
return Err(Error::ClientSocketError(
|
||||
"password message".into(),
|
||||
client_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
// Compare server and client hashes.
|
||||
let password_hash = md5_hash_password(
|
||||
&config.general.admin_username,
|
||||
&config.general.admin_password,
|
||||
&salt,
|
||||
);
|
||||
|
||||
if password_hash != password_response {
|
||||
let error =
|
||||
Error::ClientGeneralError("Invalid password".into(), client_identifier);
|
||||
|
||||
warn!("{}", error);
|
||||
wrong_password(&mut write, username).await?;
|
||||
|
||||
return Err(error);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
(false, generate_server_parameters_for_admin())
|
||||
}
|
||||
// Authenticate normal user.
|
||||
@@ -573,92 +579,143 @@ where
|
||||
// Obtain the hash to compare, we give preference to that written in cleartext in config
|
||||
// if there is nothing set in cleartext and auth passthrough (auth_query) is configured, we use the hash obtained
|
||||
// when the pool was created. If there is no hash there, we try to fetch it one more time.
|
||||
let password_hash = if let Some(password) = &pool.settings.user.password {
|
||||
Some(md5_hash_password(username, password, &salt))
|
||||
} else {
|
||||
if !get_config().is_auth_query_configured() {
|
||||
wrong_password(&mut write, username).await?;
|
||||
return Err(Error::ClientAuthImpossible(username.into()));
|
||||
}
|
||||
match pool.settings.user.auth_type {
|
||||
AuthType::Trust => (),
|
||||
AuthType::MD5 => {
|
||||
// Perform MD5 authentication.
|
||||
// TODO: Add SASL support.
|
||||
let salt = md5_challenge(&mut write).await?;
|
||||
|
||||
let mut hash = (*pool.auth_hash.read()).clone();
|
||||
let code = match read.read_u8().await {
|
||||
Ok(p) => p,
|
||||
Err(_) => {
|
||||
return Err(Error::ClientSocketError(
|
||||
"password code".into(),
|
||||
client_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
if hash.is_none() {
|
||||
warn!(
|
||||
"Query auth configured \
|
||||
but no hash password found \
|
||||
for pool {}. Will try to refetch it.",
|
||||
pool_name
|
||||
);
|
||||
// PasswordMessage
|
||||
if code as char != 'p' {
|
||||
return Err(Error::ProtocolSyncError(format!(
|
||||
"Expected p, got {}",
|
||||
code as char
|
||||
)));
|
||||
}
|
||||
|
||||
match refetch_auth_hash(&pool).await {
|
||||
Ok(fetched_hash) => {
|
||||
warn!("Password for {}, obtained. Updating.", client_identifier);
|
||||
let len = match read.read_i32().await {
|
||||
Ok(len) => len,
|
||||
Err(_) => {
|
||||
return Err(Error::ClientSocketError(
|
||||
"password message length".into(),
|
||||
client_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let mut password_response = vec![0u8; (len - 4) as usize];
|
||||
|
||||
match read.read_exact(&mut password_response).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => {
|
||||
return Err(Error::ClientSocketError(
|
||||
"password message".into(),
|
||||
client_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let password_hash = if let Some(password) = &pool.settings.user.password {
|
||||
Some(md5_hash_password(username, password, &salt))
|
||||
} else {
|
||||
if !get_config().is_auth_query_configured() {
|
||||
wrong_password(&mut write, username).await?;
|
||||
return Err(Error::ClientAuthImpossible(username.into()));
|
||||
}
|
||||
|
||||
let mut hash = (*pool.auth_hash.read()).clone();
|
||||
|
||||
if hash.is_none() {
|
||||
warn!(
|
||||
"Query auth configured \
|
||||
but no hash password found \
|
||||
for pool {}. Will try to refetch it.",
|
||||
pool_name
|
||||
);
|
||||
|
||||
match refetch_auth_hash(&pool).await {
|
||||
Ok(fetched_hash) => {
|
||||
warn!(
|
||||
"Password for {}, obtained. Updating.",
|
||||
client_identifier
|
||||
);
|
||||
|
||||
{
|
||||
let mut pool_auth_hash = pool.auth_hash.write();
|
||||
*pool_auth_hash = Some(fetched_hash.clone());
|
||||
}
|
||||
|
||||
hash = Some(fetched_hash);
|
||||
}
|
||||
|
||||
Err(err) => {
|
||||
wrong_password(&mut write, username).await?;
|
||||
|
||||
return Err(Error::ClientAuthPassthroughError(
|
||||
err.to_string(),
|
||||
client_identifier,
|
||||
));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Some(md5_hash_second_pass(&hash.unwrap(), &salt))
|
||||
};
|
||||
|
||||
// Once we have the resulting hash, we compare with what the client gave us.
|
||||
// If they do not match and auth query is set up, we try to refetch the hash one more time
|
||||
// to see if the password has changed since the pool was created.
|
||||
//
|
||||
// @TODO: we could end up fetching again the same password twice (see above).
|
||||
if password_hash.unwrap() != password_response {
|
||||
warn!(
|
||||
"Invalid password {}, will try to refetch it.",
|
||||
client_identifier
|
||||
);
|
||||
|
||||
let fetched_hash = match refetch_auth_hash(&pool).await {
|
||||
Ok(fetched_hash) => fetched_hash,
|
||||
Err(err) => {
|
||||
wrong_password(&mut write, username).await?;
|
||||
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
|
||||
let new_password_hash = md5_hash_second_pass(&fetched_hash, &salt);
|
||||
|
||||
// Ok password changed in server an auth is possible.
|
||||
if new_password_hash == password_response {
|
||||
warn!(
|
||||
"Password for {}, changed in server. Updating.",
|
||||
client_identifier
|
||||
);
|
||||
|
||||
{
|
||||
let mut pool_auth_hash = pool.auth_hash.write();
|
||||
*pool_auth_hash = Some(fetched_hash.clone());
|
||||
*pool_auth_hash = Some(fetched_hash);
|
||||
}
|
||||
|
||||
hash = Some(fetched_hash);
|
||||
}
|
||||
|
||||
Err(err) => {
|
||||
} else {
|
||||
wrong_password(&mut write, username).await?;
|
||||
|
||||
return Err(Error::ClientAuthPassthroughError(
|
||||
err.to_string(),
|
||||
return Err(Error::ClientGeneralError(
|
||||
"Invalid password".into(),
|
||||
client_identifier,
|
||||
));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Some(md5_hash_second_pass(&hash.unwrap(), &salt))
|
||||
};
|
||||
|
||||
// Once we have the resulting hash, we compare with what the client gave us.
|
||||
// If they do not match and auth query is set up, we try to refetch the hash one more time
|
||||
// to see if the password has changed since the pool was created.
|
||||
//
|
||||
// @TODO: we could end up fetching again the same password twice (see above).
|
||||
if password_hash.unwrap() != password_response {
|
||||
warn!(
|
||||
"Invalid password {}, will try to refetch it.",
|
||||
client_identifier
|
||||
);
|
||||
|
||||
let fetched_hash = match refetch_auth_hash(&pool).await {
|
||||
Ok(fetched_hash) => fetched_hash,
|
||||
Err(err) => {
|
||||
wrong_password(&mut write, username).await?;
|
||||
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
|
||||
let new_password_hash = md5_hash_second_pass(&fetched_hash, &salt);
|
||||
|
||||
// Ok password changed in server an auth is possible.
|
||||
if new_password_hash == password_response {
|
||||
warn!(
|
||||
"Password for {}, changed in server. Updating.",
|
||||
client_identifier
|
||||
);
|
||||
|
||||
{
|
||||
let mut pool_auth_hash = pool.auth_hash.write();
|
||||
*pool_auth_hash = Some(fetched_hash);
|
||||
}
|
||||
} else {
|
||||
wrong_password(&mut write, username).await?;
|
||||
return Err(Error::ClientGeneralError(
|
||||
"Invalid password".into(),
|
||||
client_identifier,
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let transaction_mode = pool.settings.pool_mode == PoolMode::Transaction;
|
||||
prepared_statements_enabled =
|
||||
transaction_mode && pool.prepared_statement_cache.is_some();
|
||||
@@ -1729,13 +1786,14 @@ where
|
||||
/// and also the pool's statement cache. Add it to extended protocol data.
|
||||
fn buffer_parse(&mut self, message: BytesMut, pool: &ConnectionPool) -> Result<(), Error> {
|
||||
// Avoid parsing if prepared statements not enabled
|
||||
let client_given_name = Parse::get_name(&message)?;
|
||||
if !self.prepared_statements_enabled || client_given_name.is_empty() {
|
||||
if !self.prepared_statements_enabled {
|
||||
debug!("Anonymous parse message");
|
||||
self.extended_protocol_data_buffer
|
||||
.push_back(ExtendedProtocolData::create_new_parse(message, None));
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let client_given_name = Parse::get_name(&message)?;
|
||||
let parse: Parse = (&message).try_into()?;
|
||||
|
||||
// Compute the hash of the parse statement
|
||||
@@ -1773,14 +1831,15 @@ where
|
||||
/// saved in the client cache.
|
||||
async fn buffer_bind(&mut self, message: BytesMut) -> Result<(), Error> {
|
||||
// Avoid parsing if prepared statements not enabled
|
||||
let client_given_name = Bind::get_name(&message)?;
|
||||
if !self.prepared_statements_enabled || client_given_name.is_empty() {
|
||||
if !self.prepared_statements_enabled {
|
||||
debug!("Anonymous bind message");
|
||||
self.extended_protocol_data_buffer
|
||||
.push_back(ExtendedProtocolData::create_new_bind(message, None));
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let client_given_name = Bind::get_name(&message)?;
|
||||
|
||||
match self.prepared_statements.get(&client_given_name) {
|
||||
Some((rewritten_parse, _)) => {
|
||||
let message = Bind::rename(message, &rewritten_parse.name)?;
|
||||
@@ -1832,8 +1891,7 @@ where
|
||||
}
|
||||
|
||||
let describe: Describe = (&message).try_into()?;
|
||||
let client_given_name = describe.statement_name.clone();
|
||||
if describe.target == 'P' || client_given_name.is_empty() {
|
||||
if describe.target == 'P' {
|
||||
debug!("Portal describe message");
|
||||
self.extended_protocol_data_buffer
|
||||
.push_back(ExtendedProtocolData::create_new_describe(message, None));
|
||||
@@ -1841,6 +1899,8 @@ where
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let client_given_name = describe.statement_name.clone();
|
||||
|
||||
match self.prepared_statements.get(&client_given_name) {
|
||||
Some((rewritten_parse, _)) => {
|
||||
let describe = describe.rename(&rewritten_parse.name);
|
||||
|
||||
@@ -208,6 +208,9 @@ impl Address {
|
||||
pub struct User {
|
||||
pub username: String,
|
||||
pub password: Option<String>,
|
||||
|
||||
#[serde(default = "User::default_auth_type")]
|
||||
pub auth_type: AuthType,
|
||||
pub server_username: Option<String>,
|
||||
pub server_password: Option<String>,
|
||||
pub pool_size: u32,
|
||||
@@ -225,6 +228,7 @@ impl Default for User {
|
||||
User {
|
||||
username: String::from("postgres"),
|
||||
password: None,
|
||||
auth_type: AuthType::MD5,
|
||||
server_username: None,
|
||||
server_password: None,
|
||||
pool_size: 15,
|
||||
@@ -239,6 +243,10 @@ impl Default for User {
|
||||
}
|
||||
|
||||
impl User {
|
||||
pub fn default_auth_type() -> AuthType {
|
||||
AuthType::MD5
|
||||
}
|
||||
|
||||
fn validate(&self) -> Result<(), Error> {
|
||||
if let Some(min_pool_size) = self.min_pool_size {
|
||||
if min_pool_size > self.pool_size {
|
||||
@@ -334,6 +342,9 @@ pub struct General {
|
||||
pub admin_username: String,
|
||||
pub admin_password: String,
|
||||
|
||||
#[serde(default = "General::default_admin_auth_type")]
|
||||
pub admin_auth_type: AuthType,
|
||||
|
||||
#[serde(default = "General::default_validate_config")]
|
||||
pub validate_config: bool,
|
||||
|
||||
@@ -348,6 +359,10 @@ impl General {
|
||||
"0.0.0.0".into()
|
||||
}
|
||||
|
||||
pub fn default_admin_auth_type() -> AuthType {
|
||||
AuthType::MD5
|
||||
}
|
||||
|
||||
pub fn default_port() -> u16 {
|
||||
5432
|
||||
}
|
||||
@@ -456,6 +471,7 @@ impl Default for General {
|
||||
verify_server_certificate: false,
|
||||
admin_username: String::from("admin"),
|
||||
admin_password: String::from("admin"),
|
||||
admin_auth_type: AuthType::MD5,
|
||||
validate_config: true,
|
||||
auth_query: None,
|
||||
auth_query_user: None,
|
||||
@@ -476,6 +492,15 @@ pub enum PoolMode {
|
||||
Session,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Copy, Hash)]
|
||||
pub enum AuthType {
|
||||
#[serde(alias = "trust", alias = "Trust")]
|
||||
Trust,
|
||||
|
||||
#[serde(alias = "md5", alias = "MD5")]
|
||||
MD5,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for PoolMode {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
|
||||
@@ -821,10 +821,10 @@ impl ExtendedProtocolData {
|
||||
pub struct Parse {
|
||||
code: char,
|
||||
#[allow(dead_code)]
|
||||
len: u32,
|
||||
len: i32,
|
||||
pub name: String,
|
||||
query: String,
|
||||
num_params: u16,
|
||||
num_params: i16,
|
||||
param_types: Vec<i32>,
|
||||
}
|
||||
|
||||
@@ -834,11 +834,12 @@ impl TryFrom<&BytesMut> for Parse {
|
||||
fn try_from(buf: &BytesMut) -> Result<Parse, Error> {
|
||||
let mut cursor = Cursor::new(buf);
|
||||
let code = cursor.get_u8() as char;
|
||||
let len = cursor.get_u32();
|
||||
let len = cursor.get_i32();
|
||||
let name = cursor.read_string()?;
|
||||
let query = cursor.read_string()?;
|
||||
let num_params = cursor.get_u16();
|
||||
let num_params = cursor.get_i16();
|
||||
let mut param_types = Vec::new();
|
||||
|
||||
for _ in 0..num_params {
|
||||
param_types.push(cursor.get_i32());
|
||||
}
|
||||
@@ -874,10 +875,10 @@ impl TryFrom<Parse> for BytesMut {
|
||||
+ 4 * parse.num_params as usize;
|
||||
|
||||
bytes.put_u8(parse.code as u8);
|
||||
bytes.put_u32(len as u32);
|
||||
bytes.put_i32(len as i32);
|
||||
bytes.put_slice(name);
|
||||
bytes.put_slice(query);
|
||||
bytes.put_u16(parse.num_params);
|
||||
bytes.put_i16(parse.num_params);
|
||||
for param in parse.param_types {
|
||||
bytes.put_i32(param);
|
||||
}
|
||||
@@ -944,14 +945,14 @@ impl Parse {
|
||||
pub struct Bind {
|
||||
code: char,
|
||||
#[allow(dead_code)]
|
||||
len: u64,
|
||||
len: i64,
|
||||
portal: String,
|
||||
pub prepared_statement: String,
|
||||
num_param_format_codes: u16,
|
||||
num_param_format_codes: i16,
|
||||
param_format_codes: Vec<i16>,
|
||||
num_param_values: u16,
|
||||
num_param_values: i16,
|
||||
param_values: Vec<(i32, BytesMut)>,
|
||||
num_result_column_format_codes: u16,
|
||||
num_result_column_format_codes: i16,
|
||||
result_columns_format_codes: Vec<i16>,
|
||||
}
|
||||
|
||||
@@ -961,17 +962,17 @@ impl TryFrom<&BytesMut> for Bind {
|
||||
fn try_from(buf: &BytesMut) -> Result<Bind, Error> {
|
||||
let mut cursor = Cursor::new(buf);
|
||||
let code = cursor.get_u8() as char;
|
||||
let len = cursor.get_u32();
|
||||
let len = cursor.get_i32();
|
||||
let portal = cursor.read_string()?;
|
||||
let prepared_statement = cursor.read_string()?;
|
||||
let num_param_format_codes = cursor.get_u16();
|
||||
let num_param_format_codes = cursor.get_i16();
|
||||
let mut param_format_codes = Vec::new();
|
||||
|
||||
for _ in 0..num_param_format_codes {
|
||||
param_format_codes.push(cursor.get_i16());
|
||||
}
|
||||
|
||||
let num_param_values = cursor.get_u16();
|
||||
let num_param_values = cursor.get_i16();
|
||||
let mut param_values = Vec::new();
|
||||
|
||||
for _ in 0..num_param_values {
|
||||
@@ -993,7 +994,7 @@ impl TryFrom<&BytesMut> for Bind {
|
||||
}
|
||||
}
|
||||
|
||||
let num_result_column_format_codes = cursor.get_u16();
|
||||
let num_result_column_format_codes = cursor.get_i16();
|
||||
let mut result_columns_format_codes = Vec::new();
|
||||
|
||||
for _ in 0..num_result_column_format_codes {
|
||||
@@ -1002,7 +1003,7 @@ impl TryFrom<&BytesMut> for Bind {
|
||||
|
||||
Ok(Bind {
|
||||
code,
|
||||
len: len as u64,
|
||||
len: len as i64,
|
||||
portal,
|
||||
prepared_statement,
|
||||
num_param_format_codes,
|
||||
@@ -1041,19 +1042,19 @@ impl TryFrom<Bind> for BytesMut {
|
||||
len += 2 * bind.num_result_column_format_codes as usize;
|
||||
|
||||
bytes.put_u8(bind.code as u8);
|
||||
bytes.put_u32(len as u32);
|
||||
bytes.put_i32(len as i32);
|
||||
bytes.put_slice(portal);
|
||||
bytes.put_slice(prepared_statement);
|
||||
bytes.put_u16(bind.num_param_format_codes);
|
||||
bytes.put_i16(bind.num_param_format_codes);
|
||||
for param_format_code in bind.param_format_codes {
|
||||
bytes.put_i16(param_format_code);
|
||||
}
|
||||
bytes.put_u16(bind.num_param_values);
|
||||
bytes.put_i16(bind.num_param_values);
|
||||
for (param_len, param) in bind.param_values {
|
||||
bytes.put_i32(param_len);
|
||||
bytes.put_slice(¶m);
|
||||
}
|
||||
bytes.put_u16(bind.num_result_column_format_codes);
|
||||
bytes.put_i16(bind.num_result_column_format_codes);
|
||||
for result_column_format_code in bind.result_columns_format_codes {
|
||||
bytes.put_i16(result_column_format_code);
|
||||
}
|
||||
@@ -1067,7 +1068,7 @@ impl Bind {
|
||||
pub fn get_name(buf: &BytesMut) -> Result<String, Error> {
|
||||
let mut cursor = Cursor::new(buf);
|
||||
// Skip the code and length
|
||||
cursor.advance(mem::size_of::<u8>() + mem::size_of::<u32>());
|
||||
cursor.advance(mem::size_of::<u8>() + mem::size_of::<i32>());
|
||||
cursor.read_string()?;
|
||||
cursor.read_string()
|
||||
}
|
||||
@@ -1077,17 +1078,17 @@ impl Bind {
|
||||
let mut cursor = Cursor::new(&buf);
|
||||
// Read basic data from the cursor
|
||||
let code = cursor.get_u8();
|
||||
let current_len = cursor.get_u32();
|
||||
let current_len = cursor.get_i32();
|
||||
let portal = cursor.read_string()?;
|
||||
let prepared_statement = cursor.read_string()?;
|
||||
|
||||
// Calculate new length
|
||||
let new_len = current_len + new_name.len() as u32 - prepared_statement.len() as u32;
|
||||
let new_len = current_len + new_name.len() as i32 - prepared_statement.len() as i32;
|
||||
|
||||
// Begin building the response buffer
|
||||
let mut response_buf = BytesMut::with_capacity(new_len as usize + 1);
|
||||
response_buf.put_u8(code);
|
||||
response_buf.put_u32(new_len);
|
||||
response_buf.put_i32(new_len);
|
||||
|
||||
// Put the portal and new name into the buffer
|
||||
// Note: panic if the provided string contains null byte
|
||||
@@ -1111,7 +1112,7 @@ pub struct Describe {
|
||||
code: char,
|
||||
|
||||
#[allow(dead_code)]
|
||||
len: u32,
|
||||
len: i32,
|
||||
pub target: char,
|
||||
pub statement_name: String,
|
||||
}
|
||||
@@ -1122,7 +1123,7 @@ impl TryFrom<&BytesMut> for Describe {
|
||||
fn try_from(bytes: &BytesMut) -> Result<Describe, Error> {
|
||||
let mut cursor = Cursor::new(bytes);
|
||||
let code = cursor.get_u8() as char;
|
||||
let len = cursor.get_u32();
|
||||
let len = cursor.get_i32();
|
||||
let target = cursor.get_u8() as char;
|
||||
let statement_name = cursor.read_string()?;
|
||||
|
||||
@@ -1145,7 +1146,7 @@ impl TryFrom<Describe> for BytesMut {
|
||||
let len = 4 + 1 + statement_name.len();
|
||||
|
||||
bytes.put_u8(describe.code as u8);
|
||||
bytes.put_u32(len as u32);
|
||||
bytes.put_i32(len as i32);
|
||||
bytes.put_u8(describe.target as u8);
|
||||
bytes.put_slice(statement_name);
|
||||
|
||||
|
||||
@@ -200,18 +200,17 @@ struct PrometheusMetric<Value: fmt::Display> {
|
||||
|
||||
impl<Value: fmt::Display> fmt::Display for PrometheusMetric<Value> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
let formatted_labels = self
|
||||
.labels
|
||||
let mut sorted_labels: Vec<_> = self.labels.iter().collect();
|
||||
sorted_labels.sort_by_key(|&(key, _)| key);
|
||||
let formatted_labels = sorted_labels
|
||||
.iter()
|
||||
.map(|(key, value)| format!("{}=\"{}\"", key, value))
|
||||
.collect::<Vec<_>>()
|
||||
.join(",");
|
||||
write!(
|
||||
f,
|
||||
"# HELP {name} {help}\n# TYPE {name} {ty}\n{name}{{{formatted_labels}}} {value}\n",
|
||||
"{name}{{{formatted_labels}}} {value}",
|
||||
name = format_args!("pgcat_{}", self.name),
|
||||
help = self.help,
|
||||
ty = self.ty,
|
||||
formatted_labels = formatted_labels,
|
||||
value = self.value
|
||||
)
|
||||
@@ -247,7 +246,7 @@ impl<Value: fmt::Display> PrometheusMetric<Value> {
|
||||
labels.insert("pool", address.pool_name.clone());
|
||||
labels.insert("index", address.address_index.to_string());
|
||||
labels.insert("database", address.database.to_string());
|
||||
labels.insert("user", address.username.clone());
|
||||
labels.insert("username", address.username.clone());
|
||||
|
||||
Self::from_name(&format!("databases_{}", name), value, labels)
|
||||
}
|
||||
@@ -264,7 +263,8 @@ impl<Value: fmt::Display> PrometheusMetric<Value> {
|
||||
labels.insert("pool", address.pool_name.clone());
|
||||
labels.insert("index", address.address_index.to_string());
|
||||
labels.insert("database", address.database.to_string());
|
||||
labels.insert("user", address.username.clone());
|
||||
labels.insert("username", address.username.clone());
|
||||
|
||||
Self::from_name(&format!("servers_{}", name), value, labels)
|
||||
}
|
||||
|
||||
@@ -276,7 +276,7 @@ impl<Value: fmt::Display> PrometheusMetric<Value> {
|
||||
labels.insert("role", address.role.to_string());
|
||||
labels.insert("index", address.address_index.to_string());
|
||||
labels.insert("database", address.database.to_string());
|
||||
labels.insert("user", address.username.clone());
|
||||
labels.insert("username", address.username.clone());
|
||||
|
||||
Self::from_name(&format!("stats_{}", name), value, labels)
|
||||
}
|
||||
@@ -288,6 +288,15 @@ impl<Value: fmt::Display> PrometheusMetric<Value> {
|
||||
|
||||
Self::from_name(&format!("pools_{}", name), value, labels)
|
||||
}
|
||||
|
||||
fn get_header(&self) -> String {
|
||||
format!(
|
||||
"\n# HELP {name} {help}\n# TYPE {name} {ty}",
|
||||
name = format_args!("pgcat_{}", self.name),
|
||||
help = self.help,
|
||||
ty = self.ty,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
async fn prometheus_stats(
|
||||
@@ -313,6 +322,7 @@ async fn prometheus_stats(
|
||||
|
||||
// Adds metrics shown in a SHOW STATS admin command.
|
||||
fn push_address_stats(lines: &mut Vec<String>) {
|
||||
let mut grouped_metrics: HashMap<String, Vec<PrometheusMetric<u64>>> = HashMap::new();
|
||||
for (_, pool) in get_all_pools() {
|
||||
for shard in 0..pool.shards() {
|
||||
for server in 0..pool.servers(shard) {
|
||||
@@ -322,7 +332,10 @@ fn push_address_stats(lines: &mut Vec<String>) {
|
||||
if let Some(prometheus_metric) =
|
||||
PrometheusMetric::<u64>::from_address(address, &key, value)
|
||||
{
|
||||
lines.push(prometheus_metric.to_string());
|
||||
grouped_metrics
|
||||
.entry(key)
|
||||
.or_default()
|
||||
.push(prometheus_metric);
|
||||
} else {
|
||||
debug!("Metric {} not implemented for {}", key, address.name());
|
||||
}
|
||||
@@ -330,33 +343,53 @@ fn push_address_stats(lines: &mut Vec<String>) {
|
||||
}
|
||||
}
|
||||
}
|
||||
for (_key, metrics) in grouped_metrics {
|
||||
if !metrics.is_empty() {
|
||||
lines.push(metrics[0].get_header());
|
||||
for metric in metrics {
|
||||
lines.push(metric.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Adds relevant metrics shown in a SHOW POOLS admin command.
|
||||
fn push_pool_stats(lines: &mut Vec<String>) {
|
||||
let mut grouped_metrics: HashMap<String, Vec<PrometheusMetric<u64>>> = HashMap::new();
|
||||
let pool_stats = PoolStats::construct_pool_lookup();
|
||||
for (pool_id, stats) in pool_stats.iter() {
|
||||
for (name, value) in stats.clone() {
|
||||
if let Some(prometheus_metric) =
|
||||
PrometheusMetric::<u64>::from_pool(pool_id.clone(), &name, value)
|
||||
{
|
||||
lines.push(prometheus_metric.to_string());
|
||||
grouped_metrics
|
||||
.entry(name)
|
||||
.or_default()
|
||||
.push(prometheus_metric);
|
||||
} else {
|
||||
debug!("Metric {} not implemented for ({})", name, *pool_id);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (_key, metrics) in grouped_metrics {
|
||||
if !metrics.is_empty() {
|
||||
lines.push(metrics[0].get_header());
|
||||
for metric in metrics {
|
||||
lines.push(metric.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Adds relevant metrics shown in a SHOW DATABASES admin command.
|
||||
fn push_database_stats(lines: &mut Vec<String>) {
|
||||
let mut grouped_metrics: HashMap<String, Vec<PrometheusMetric<u32>>> = HashMap::new();
|
||||
for (_, pool) in get_all_pools() {
|
||||
let pool_config = pool.settings.clone();
|
||||
for shard in 0..pool.shards() {
|
||||
for server in 0..pool.servers(shard) {
|
||||
let address = pool.address(shard, server);
|
||||
let pool_state = pool.pool_state(shard, server);
|
||||
|
||||
let metrics = vec![
|
||||
("pool_size", pool_config.user.pool_size),
|
||||
("current_connections", pool_state.connections),
|
||||
@@ -365,7 +398,10 @@ fn push_database_stats(lines: &mut Vec<String>) {
|
||||
if let Some(prometheus_metric) =
|
||||
PrometheusMetric::<u32>::from_database_info(address, key, value)
|
||||
{
|
||||
lines.push(prometheus_metric.to_string());
|
||||
grouped_metrics
|
||||
.entry(key.to_string())
|
||||
.or_default()
|
||||
.push(prometheus_metric);
|
||||
} else {
|
||||
debug!("Metric {} not implemented for {}", key, address.name());
|
||||
}
|
||||
@@ -373,6 +409,14 @@ fn push_database_stats(lines: &mut Vec<String>) {
|
||||
}
|
||||
}
|
||||
}
|
||||
for (_key, metrics) in grouped_metrics {
|
||||
if !metrics.is_empty() {
|
||||
lines.push(metrics[0].get_header());
|
||||
for metric in metrics {
|
||||
lines.push(metric.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Adds relevant metrics shown in a SHOW SERVERS admin command.
|
||||
@@ -405,7 +449,7 @@ fn push_server_stats(lines: &mut Vec<String>) {
|
||||
crate::stats::ServerState::Idle => entry.idle_count += 1,
|
||||
}
|
||||
}
|
||||
|
||||
let mut grouped_metrics: HashMap<String, Vec<PrometheusMetric<u64>>> = HashMap::new();
|
||||
for (_, pool) in get_all_pools() {
|
||||
for shard in 0..pool.shards() {
|
||||
for server in 0..pool.servers(shard) {
|
||||
@@ -428,7 +472,10 @@ fn push_server_stats(lines: &mut Vec<String>) {
|
||||
if let Some(prometheus_metric) =
|
||||
PrometheusMetric::<u64>::from_server_info(address, key, value)
|
||||
{
|
||||
lines.push(prometheus_metric.to_string());
|
||||
grouped_metrics
|
||||
.entry(key.to_string())
|
||||
.or_default()
|
||||
.push(prometheus_metric);
|
||||
} else {
|
||||
debug!("Metric {} not implemented for {}", key, address.name());
|
||||
}
|
||||
@@ -437,6 +484,14 @@ fn push_server_stats(lines: &mut Vec<String>) {
|
||||
}
|
||||
}
|
||||
}
|
||||
for (_key, metrics) in grouped_metrics {
|
||||
if !metrics.is_empty() {
|
||||
lines.push(metrics[0].get_header());
|
||||
for metric in metrics {
|
||||
lines.push(metric.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn start_metric_server(http_addr: SocketAddr) {
|
||||
|
||||
@@ -698,6 +698,7 @@ impl Server {
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
trace!("Error: {}", error_code);
|
||||
|
||||
match error_code {
|
||||
@@ -1012,12 +1013,6 @@ impl Server {
|
||||
// which can leak between clients. This is a best effort to block bad clients
|
||||
// from poisoning a transaction-mode pool by setting inappropriate session variables
|
||||
match command.as_str() {
|
||||
"DISCARD ALL" => {
|
||||
self.clear_prepared_statement_cache();
|
||||
}
|
||||
"DEALLOCATE ALL" => {
|
||||
self.clear_prepared_statement_cache();
|
||||
}
|
||||
"SET" => {
|
||||
// We don't detect set statements in transactions
|
||||
// No great way to differentiate between set and set local
|
||||
@@ -1137,12 +1132,6 @@ impl Server {
|
||||
has_it
|
||||
}
|
||||
|
||||
fn clear_prepared_statement_cache(&mut self) {
|
||||
if let Some(cache) = &mut self.prepared_statement_cache {
|
||||
cache.clear();
|
||||
}
|
||||
}
|
||||
|
||||
fn add_prepared_statement_to_cache(&mut self, name: &str) -> Option<String> {
|
||||
let cache = match &mut self.prepared_statement_cache {
|
||||
Some(cache) => cache,
|
||||
|
||||
@@ -23,11 +23,11 @@ docker compose exec --workdir /app/tests/python main pip3 install -r requirement
|
||||
echo "Interactive test environment ready"
|
||||
echo "To run integration tests, you can use the following commands:"
|
||||
echo -e " ${BLUE}Ruby: ${RED}cd /app/tests/ruby && bundle exec ruby tests.rb --format documentation${RESET}"
|
||||
echo -e " ${BLUE}Python: ${RED}cd /app && python3 tests/python/tests.py${RESET}"
|
||||
echo -e " ${BLUE}Python: ${RED}cd /app/ && pytest ${RESET}"
|
||||
echo -e " ${BLUE}Rust: ${RED}cd /app/tests/rust && cargo run ${RESET}"
|
||||
echo -e " ${BLUE}Go: ${RED}cd /app/tests/go && /usr/local/go/bin/go test${RESET}"
|
||||
echo "the source code for tests are directly linked to the source code in the container so you can modify the code and run the tests again"
|
||||
echo "You can rebuild PgCat from within the container by running"
|
||||
echo "You can rebuild PgCat from within the container by running"
|
||||
echo -e " ${GREEN}cargo build${RESET}"
|
||||
echo "and then run the tests again"
|
||||
echo "==================================="
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
pytest
|
||||
psycopg2==2.9.3
|
||||
psutil==5.9.1
|
||||
psutil==5.9.1
|
||||
|
||||
71
tests/python/test_auth.py
Normal file
71
tests/python/test_auth.py
Normal file
@@ -0,0 +1,71 @@
|
||||
import utils
|
||||
import signal
|
||||
|
||||
class TestTrustAuth:
|
||||
@classmethod
|
||||
def setup_method(cls):
|
||||
config= """
|
||||
[general]
|
||||
host = "0.0.0.0"
|
||||
port = 6432
|
||||
admin_username = "admin_user"
|
||||
admin_password = ""
|
||||
admin_auth_type = "trust"
|
||||
|
||||
[pools.sharded_db.users.0]
|
||||
username = "sharding_user"
|
||||
password = "sharding_user"
|
||||
auth_type = "trust"
|
||||
pool_size = 10
|
||||
min_pool_size = 1
|
||||
pool_mode = "transaction"
|
||||
|
||||
[pools.sharded_db.shards.0]
|
||||
servers = [
|
||||
[ "127.0.0.1", 5432, "primary" ],
|
||||
]
|
||||
database = "shard0"
|
||||
"""
|
||||
utils.pgcat_generic_start(config)
|
||||
|
||||
@classmethod
|
||||
def teardown_method(self):
|
||||
utils.pg_cat_send_signal(signal.SIGTERM)
|
||||
|
||||
def test_admin_trust_auth(self):
|
||||
conn, cur = utils.connect_db_trust(admin=True)
|
||||
cur.execute("SHOW POOLS")
|
||||
res = cur.fetchall()
|
||||
print(res)
|
||||
utils.cleanup_conn(conn, cur)
|
||||
|
||||
def test_normal_trust_auth(self):
|
||||
conn, cur = utils.connect_db_trust(autocommit=False)
|
||||
cur.execute("SELECT 1")
|
||||
res = cur.fetchall()
|
||||
print(res)
|
||||
utils.cleanup_conn(conn, cur)
|
||||
|
||||
class TestMD5Auth:
|
||||
@classmethod
|
||||
def setup_method(cls):
|
||||
utils.pgcat_start()
|
||||
|
||||
@classmethod
|
||||
def teardown_method(self):
|
||||
utils.pg_cat_send_signal(signal.SIGTERM)
|
||||
|
||||
def test_normal_db_access(self):
|
||||
conn, cur = utils.connect_db(autocommit=False)
|
||||
cur.execute("SELECT 1")
|
||||
res = cur.fetchall()
|
||||
print(res)
|
||||
utils.cleanup_conn(conn, cur)
|
||||
|
||||
def test_admin_db_access(self):
|
||||
conn, cur = utils.connect_db(admin=True)
|
||||
|
||||
cur.execute("SHOW POOLS")
|
||||
res = cur.fetchall()
|
||||
print(res)
|
||||
utils.cleanup_conn(conn, cur)
|
||||
@@ -1,84 +1,12 @@
|
||||
from typing import Tuple
|
||||
import psycopg2
|
||||
import psutil
|
||||
import os
|
||||
|
||||
import signal
|
||||
import time
|
||||
|
||||
import psycopg2
|
||||
import utils
|
||||
|
||||
SHUTDOWN_TIMEOUT = 5
|
||||
|
||||
PGCAT_HOST = "127.0.0.1"
|
||||
PGCAT_PORT = "6432"
|
||||
|
||||
|
||||
def pgcat_start():
|
||||
pg_cat_send_signal(signal.SIGTERM)
|
||||
os.system("./target/debug/pgcat .circleci/pgcat.toml &")
|
||||
time.sleep(2)
|
||||
|
||||
|
||||
def pg_cat_send_signal(signal: signal.Signals):
|
||||
try:
|
||||
for proc in psutil.process_iter(["pid", "name"]):
|
||||
if "pgcat" == proc.name():
|
||||
os.kill(proc.pid, signal)
|
||||
except Exception as e:
|
||||
# The process can be gone when we send this signal
|
||||
print(e)
|
||||
|
||||
if signal == signal.SIGTERM:
|
||||
# Returns 0 if pgcat process exists
|
||||
time.sleep(2)
|
||||
if not os.system('pgrep pgcat'):
|
||||
raise Exception("pgcat not closed after SIGTERM")
|
||||
|
||||
|
||||
def connect_db(
|
||||
autocommit: bool = True,
|
||||
admin: bool = False,
|
||||
) -> Tuple[psycopg2.extensions.connection, psycopg2.extensions.cursor]:
|
||||
|
||||
if admin:
|
||||
user = "admin_user"
|
||||
password = "admin_pass"
|
||||
db = "pgcat"
|
||||
else:
|
||||
user = "sharding_user"
|
||||
password = "sharding_user"
|
||||
db = "sharded_db"
|
||||
|
||||
conn = psycopg2.connect(
|
||||
f"postgres://{user}:{password}@{PGCAT_HOST}:{PGCAT_PORT}/{db}?application_name=testing_pgcat",
|
||||
connect_timeout=2,
|
||||
)
|
||||
conn.autocommit = autocommit
|
||||
cur = conn.cursor()
|
||||
|
||||
return (conn, cur)
|
||||
|
||||
|
||||
def cleanup_conn(conn: psycopg2.extensions.connection, cur: psycopg2.extensions.cursor):
|
||||
cur.close()
|
||||
conn.close()
|
||||
|
||||
|
||||
def test_normal_db_access():
|
||||
pgcat_start()
|
||||
conn, cur = connect_db(autocommit=False)
|
||||
cur.execute("SELECT 1")
|
||||
res = cur.fetchall()
|
||||
print(res)
|
||||
cleanup_conn(conn, cur)
|
||||
|
||||
|
||||
def test_admin_db_access():
|
||||
conn, cur = connect_db(admin=True)
|
||||
|
||||
cur.execute("SHOW POOLS")
|
||||
res = cur.fetchall()
|
||||
print(res)
|
||||
cleanup_conn(conn, cur)
|
||||
|
||||
|
||||
def test_shutdown_logic():
|
||||
|
||||
@@ -86,17 +14,17 @@ def test_shutdown_logic():
|
||||
# NO ACTIVE QUERIES SIGINT HANDLING
|
||||
|
||||
# Start pgcat
|
||||
pgcat_start()
|
||||
utils.pgcat_start()
|
||||
|
||||
# Create client connection and send query (not in transaction)
|
||||
conn, cur = connect_db()
|
||||
conn, cur = utils.connect_db()
|
||||
|
||||
cur.execute("BEGIN;")
|
||||
cur.execute("SELECT 1;")
|
||||
cur.execute("COMMIT;")
|
||||
|
||||
# Send sigint to pgcat
|
||||
pg_cat_send_signal(signal.SIGINT)
|
||||
utils.pg_cat_send_signal(signal.SIGINT)
|
||||
time.sleep(1)
|
||||
|
||||
# Check that any new queries fail after sigint since server should close with no active transactions
|
||||
@@ -108,18 +36,18 @@ def test_shutdown_logic():
|
||||
# Fail if query execution succeeded
|
||||
raise Exception("Server not closed after sigint")
|
||||
|
||||
cleanup_conn(conn, cur)
|
||||
pg_cat_send_signal(signal.SIGTERM)
|
||||
utils.cleanup_conn(conn, cur)
|
||||
utils.pg_cat_send_signal(signal.SIGTERM)
|
||||
|
||||
# - - - - - - - - - - - - - - - - - -
|
||||
# NO ACTIVE QUERIES ADMIN SHUTDOWN COMMAND
|
||||
|
||||
# Start pgcat
|
||||
pgcat_start()
|
||||
utils.pgcat_start()
|
||||
|
||||
# Create client connection and begin transaction
|
||||
conn, cur = connect_db()
|
||||
admin_conn, admin_cur = connect_db(admin=True)
|
||||
conn, cur = utils.connect_db()
|
||||
admin_conn, admin_cur = utils.connect_db(admin=True)
|
||||
|
||||
cur.execute("BEGIN;")
|
||||
cur.execute("SELECT 1;")
|
||||
@@ -138,24 +66,24 @@ def test_shutdown_logic():
|
||||
# Fail if query execution succeeded
|
||||
raise Exception("Server not closed after sigint")
|
||||
|
||||
cleanup_conn(conn, cur)
|
||||
cleanup_conn(admin_conn, admin_cur)
|
||||
pg_cat_send_signal(signal.SIGTERM)
|
||||
utils.cleanup_conn(conn, cur)
|
||||
utils.cleanup_conn(admin_conn, admin_cur)
|
||||
utils.pg_cat_send_signal(signal.SIGTERM)
|
||||
|
||||
# - - - - - - - - - - - - - - - - - -
|
||||
# HANDLE TRANSACTION WITH SIGINT
|
||||
|
||||
# Start pgcat
|
||||
pgcat_start()
|
||||
utils.pgcat_start()
|
||||
|
||||
# Create client connection and begin transaction
|
||||
conn, cur = connect_db()
|
||||
conn, cur = utils.connect_db()
|
||||
|
||||
cur.execute("BEGIN;")
|
||||
cur.execute("SELECT 1;")
|
||||
|
||||
# Send sigint to pgcat while still in transaction
|
||||
pg_cat_send_signal(signal.SIGINT)
|
||||
utils.pg_cat_send_signal(signal.SIGINT)
|
||||
time.sleep(1)
|
||||
|
||||
# Check that any new queries succeed after sigint since server should still allow transaction to complete
|
||||
@@ -165,18 +93,18 @@ def test_shutdown_logic():
|
||||
# Fail if query fails since server closed
|
||||
raise Exception("Server closed while in transaction", e.pgerror)
|
||||
|
||||
cleanup_conn(conn, cur)
|
||||
pg_cat_send_signal(signal.SIGTERM)
|
||||
utils.cleanup_conn(conn, cur)
|
||||
utils.pg_cat_send_signal(signal.SIGTERM)
|
||||
|
||||
# - - - - - - - - - - - - - - - - - -
|
||||
# HANDLE TRANSACTION WITH ADMIN SHUTDOWN COMMAND
|
||||
|
||||
# Start pgcat
|
||||
pgcat_start()
|
||||
utils.pgcat_start()
|
||||
|
||||
# Create client connection and begin transaction
|
||||
conn, cur = connect_db()
|
||||
admin_conn, admin_cur = connect_db(admin=True)
|
||||
conn, cur = utils.connect_db()
|
||||
admin_conn, admin_cur = utils.connect_db(admin=True)
|
||||
|
||||
cur.execute("BEGIN;")
|
||||
cur.execute("SELECT 1;")
|
||||
@@ -194,30 +122,30 @@ def test_shutdown_logic():
|
||||
# Fail if query fails since server closed
|
||||
raise Exception("Server closed while in transaction", e.pgerror)
|
||||
|
||||
cleanup_conn(conn, cur)
|
||||
cleanup_conn(admin_conn, admin_cur)
|
||||
pg_cat_send_signal(signal.SIGTERM)
|
||||
utils.cleanup_conn(conn, cur)
|
||||
utils.cleanup_conn(admin_conn, admin_cur)
|
||||
utils.pg_cat_send_signal(signal.SIGTERM)
|
||||
|
||||
# - - - - - - - - - - - - - - - - - -
|
||||
# NO NEW NON-ADMIN CONNECTIONS DURING SHUTDOWN
|
||||
# Start pgcat
|
||||
pgcat_start()
|
||||
utils.pgcat_start()
|
||||
|
||||
# Create client connection and begin transaction
|
||||
transaction_conn, transaction_cur = connect_db()
|
||||
transaction_conn, transaction_cur = utils.connect_db()
|
||||
|
||||
transaction_cur.execute("BEGIN;")
|
||||
transaction_cur.execute("SELECT 1;")
|
||||
|
||||
# Send sigint to pgcat while still in transaction
|
||||
pg_cat_send_signal(signal.SIGINT)
|
||||
utils.pg_cat_send_signal(signal.SIGINT)
|
||||
time.sleep(1)
|
||||
|
||||
start = time.perf_counter()
|
||||
try:
|
||||
conn, cur = connect_db()
|
||||
conn, cur = utils.connect_db()
|
||||
cur.execute("SELECT 1;")
|
||||
cleanup_conn(conn, cur)
|
||||
utils.cleanup_conn(conn, cur)
|
||||
except psycopg2.OperationalError as e:
|
||||
time_taken = time.perf_counter() - start
|
||||
if time_taken > 0.1:
|
||||
@@ -227,49 +155,49 @@ def test_shutdown_logic():
|
||||
else:
|
||||
raise Exception("Able connect to database during shutdown")
|
||||
|
||||
cleanup_conn(transaction_conn, transaction_cur)
|
||||
pg_cat_send_signal(signal.SIGTERM)
|
||||
utils.cleanup_conn(transaction_conn, transaction_cur)
|
||||
utils.pg_cat_send_signal(signal.SIGTERM)
|
||||
|
||||
# - - - - - - - - - - - - - - - - - -
|
||||
# ALLOW NEW ADMIN CONNECTIONS DURING SHUTDOWN
|
||||
# Start pgcat
|
||||
pgcat_start()
|
||||
utils.pgcat_start()
|
||||
|
||||
# Create client connection and begin transaction
|
||||
transaction_conn, transaction_cur = connect_db()
|
||||
transaction_conn, transaction_cur = utils.connect_db()
|
||||
|
||||
transaction_cur.execute("BEGIN;")
|
||||
transaction_cur.execute("SELECT 1;")
|
||||
|
||||
# Send sigint to pgcat while still in transaction
|
||||
pg_cat_send_signal(signal.SIGINT)
|
||||
utils.pg_cat_send_signal(signal.SIGINT)
|
||||
time.sleep(1)
|
||||
|
||||
try:
|
||||
conn, cur = connect_db(admin=True)
|
||||
conn, cur = utils.connect_db(admin=True)
|
||||
cur.execute("SHOW DATABASES;")
|
||||
cleanup_conn(conn, cur)
|
||||
utils.cleanup_conn(conn, cur)
|
||||
except psycopg2.OperationalError as e:
|
||||
raise Exception(e)
|
||||
|
||||
cleanup_conn(transaction_conn, transaction_cur)
|
||||
pg_cat_send_signal(signal.SIGTERM)
|
||||
utils.cleanup_conn(transaction_conn, transaction_cur)
|
||||
utils.pg_cat_send_signal(signal.SIGTERM)
|
||||
|
||||
# - - - - - - - - - - - - - - - - - -
|
||||
# ADMIN CONNECTIONS CONTINUING TO WORK AFTER SHUTDOWN
|
||||
# Start pgcat
|
||||
pgcat_start()
|
||||
utils.pgcat_start()
|
||||
|
||||
# Create client connection and begin transaction
|
||||
transaction_conn, transaction_cur = connect_db()
|
||||
transaction_conn, transaction_cur = utils.connect_db()
|
||||
transaction_cur.execute("BEGIN;")
|
||||
transaction_cur.execute("SELECT 1;")
|
||||
|
||||
admin_conn, admin_cur = connect_db(admin=True)
|
||||
admin_conn, admin_cur = utils.connect_db(admin=True)
|
||||
admin_cur.execute("SHOW DATABASES;")
|
||||
|
||||
# Send sigint to pgcat while still in transaction
|
||||
pg_cat_send_signal(signal.SIGINT)
|
||||
utils.pg_cat_send_signal(signal.SIGINT)
|
||||
time.sleep(1)
|
||||
|
||||
try:
|
||||
@@ -277,24 +205,24 @@ def test_shutdown_logic():
|
||||
except psycopg2.OperationalError as e:
|
||||
raise Exception("Could not execute admin command:", e)
|
||||
|
||||
cleanup_conn(transaction_conn, transaction_cur)
|
||||
cleanup_conn(admin_conn, admin_cur)
|
||||
pg_cat_send_signal(signal.SIGTERM)
|
||||
utils.cleanup_conn(transaction_conn, transaction_cur)
|
||||
utils.cleanup_conn(admin_conn, admin_cur)
|
||||
utils.pg_cat_send_signal(signal.SIGTERM)
|
||||
|
||||
# - - - - - - - - - - - - - - - - - -
|
||||
# HANDLE SHUTDOWN TIMEOUT WITH SIGINT
|
||||
|
||||
# Start pgcat
|
||||
pgcat_start()
|
||||
utils.pgcat_start()
|
||||
|
||||
# Create client connection and begin transaction, which should prevent server shutdown unless shutdown timeout is reached
|
||||
conn, cur = connect_db()
|
||||
conn, cur = utils.connect_db()
|
||||
|
||||
cur.execute("BEGIN;")
|
||||
cur.execute("SELECT 1;")
|
||||
|
||||
# Send sigint to pgcat while still in transaction
|
||||
pg_cat_send_signal(signal.SIGINT)
|
||||
utils.pg_cat_send_signal(signal.SIGINT)
|
||||
|
||||
# pgcat shutdown timeout is set to SHUTDOWN_TIMEOUT seconds, so we sleep for SHUTDOWN_TIMEOUT + 1 seconds
|
||||
time.sleep(SHUTDOWN_TIMEOUT + 1)
|
||||
@@ -308,12 +236,7 @@ def test_shutdown_logic():
|
||||
# Fail if query execution succeeded
|
||||
raise Exception("Server not closed after sigint and expected timeout")
|
||||
|
||||
cleanup_conn(conn, cur)
|
||||
pg_cat_send_signal(signal.SIGTERM)
|
||||
utils.cleanup_conn(conn, cur)
|
||||
utils.pg_cat_send_signal(signal.SIGTERM)
|
||||
|
||||
# - - - - - - - - - - - - - - - - - -
|
||||
|
||||
|
||||
test_normal_db_access()
|
||||
test_admin_db_access()
|
||||
test_shutdown_logic()
|
||||
110
tests/python/utils.py
Normal file
110
tests/python/utils.py
Normal file
@@ -0,0 +1,110 @@
|
||||
import os
|
||||
import signal
|
||||
import time
|
||||
from typing import Tuple
|
||||
import tempfile
|
||||
|
||||
import psutil
|
||||
import psycopg2
|
||||
|
||||
PGCAT_HOST = "127.0.0.1"
|
||||
PGCAT_PORT = "6432"
|
||||
|
||||
|
||||
def _pgcat_start(config_path: str):
|
||||
pg_cat_send_signal(signal.SIGTERM)
|
||||
os.system(f"./target/debug/pgcat {config_path} &")
|
||||
time.sleep(2)
|
||||
|
||||
|
||||
def pgcat_start():
|
||||
_pgcat_start(config_path='.circleci/pgcat.toml')
|
||||
|
||||
|
||||
def pgcat_generic_start(config: str):
|
||||
tmp = tempfile.NamedTemporaryFile()
|
||||
with open(tmp.name, 'w') as f:
|
||||
f.write(config)
|
||||
_pgcat_start(config_path=tmp.name)
|
||||
|
||||
|
||||
def glauth_send_signal(signal: signal.Signals):
|
||||
try:
|
||||
for proc in psutil.process_iter(["pid", "name"]):
|
||||
if proc.name() == "glauth":
|
||||
os.kill(proc.pid, signal)
|
||||
except Exception as e:
|
||||
# The process can be gone when we send this signal
|
||||
print(e)
|
||||
|
||||
if signal == signal.SIGTERM:
|
||||
# Returns 0 if pgcat process exists
|
||||
time.sleep(2)
|
||||
if not os.system('pgrep glauth'):
|
||||
raise Exception("glauth not closed after SIGTERM")
|
||||
|
||||
|
||||
def pg_cat_send_signal(signal: signal.Signals):
|
||||
try:
|
||||
for proc in psutil.process_iter(["pid", "name"]):
|
||||
if "pgcat" == proc.name():
|
||||
os.kill(proc.pid, signal)
|
||||
except Exception as e:
|
||||
# The process can be gone when we send this signal
|
||||
print(e)
|
||||
|
||||
if signal == signal.SIGTERM:
|
||||
# Returns 0 if pgcat process exists
|
||||
time.sleep(2)
|
||||
if not os.system('pgrep pgcat'):
|
||||
raise Exception("pgcat not closed after SIGTERM")
|
||||
|
||||
|
||||
def connect_db(
|
||||
autocommit: bool = True,
|
||||
admin: bool = False,
|
||||
) -> Tuple[psycopg2.extensions.connection, psycopg2.extensions.cursor]:
|
||||
|
||||
if admin:
|
||||
user = "admin_user"
|
||||
password = "admin_pass"
|
||||
db = "pgcat"
|
||||
else:
|
||||
user = "sharding_user"
|
||||
password = "sharding_user"
|
||||
db = "sharded_db"
|
||||
|
||||
conn = psycopg2.connect(
|
||||
f"postgres://{user}:{password}@{PGCAT_HOST}:{PGCAT_PORT}/{db}?application_name=testing_pgcat",
|
||||
connect_timeout=2,
|
||||
)
|
||||
conn.autocommit = autocommit
|
||||
cur = conn.cursor()
|
||||
|
||||
return (conn, cur)
|
||||
|
||||
def connect_db_trust(
|
||||
autocommit: bool = True,
|
||||
admin: bool = False,
|
||||
) -> Tuple[psycopg2.extensions.connection, psycopg2.extensions.cursor]:
|
||||
|
||||
if admin:
|
||||
user = "admin_user"
|
||||
db = "pgcat"
|
||||
else:
|
||||
user = "sharding_user"
|
||||
db = "sharded_db"
|
||||
|
||||
conn = psycopg2.connect(
|
||||
f"postgres://{user}@{PGCAT_HOST}:{PGCAT_PORT}/{db}?application_name=testing_pgcat",
|
||||
connect_timeout=2,
|
||||
)
|
||||
conn.autocommit = autocommit
|
||||
cur = conn.cursor()
|
||||
|
||||
return (conn, cur)
|
||||
|
||||
|
||||
def cleanup_conn(conn: psycopg2.extensions.connection, cur: psycopg2.extensions.cursor):
|
||||
cur.close()
|
||||
conn.close()
|
||||
@@ -1,145 +0,0 @@
|
||||
|
||||
class PostgresMessage
|
||||
# Base class for common functionality
|
||||
|
||||
def encode_string(str)
|
||||
"#{str}\0" # Encode a string with a null terminator
|
||||
end
|
||||
|
||||
def encode_int16(value)
|
||||
[value].pack('n') # Encode an Int16
|
||||
end
|
||||
|
||||
def encode_int32(value)
|
||||
[value].pack('N') # Encode an Int32
|
||||
end
|
||||
|
||||
def message_prefix(type, length)
|
||||
"#{type}#{encode_int32(length)}" # Message type and length prefix
|
||||
end
|
||||
end
|
||||
|
||||
class SimpleQueryMessage < PostgresMessage
|
||||
attr_accessor :query
|
||||
|
||||
def initialize(query = "")
|
||||
@query = query
|
||||
end
|
||||
|
||||
def to_bytes
|
||||
query_bytes = encode_string(@query)
|
||||
length = 4 + query_bytes.size # Length includes 4 bytes for length itself
|
||||
message_prefix('Q', length) + query_bytes
|
||||
end
|
||||
end
|
||||
|
||||
class ParseMessage < PostgresMessage
|
||||
attr_accessor :statement_name, :query, :parameter_types
|
||||
|
||||
def initialize(statement_name = "", query = "", parameter_types = [])
|
||||
@statement_name = statement_name
|
||||
@query = query
|
||||
@parameter_types = parameter_types
|
||||
end
|
||||
|
||||
def to_bytes
|
||||
statement_name_bytes = encode_string(@statement_name)
|
||||
query_bytes = encode_string(@query)
|
||||
parameter_types_bytes = @parameter_types.pack('N*')
|
||||
|
||||
length = 4 + statement_name_bytes.size + query_bytes.size + 2 + parameter_types_bytes.size
|
||||
message_prefix('P', length) + statement_name_bytes + query_bytes + encode_int16(@parameter_types.size) + parameter_types_bytes
|
||||
end
|
||||
end
|
||||
|
||||
class BindMessage < PostgresMessage
|
||||
attr_accessor :portal_name, :statement_name, :parameter_format_codes, :parameters, :result_column_format_codes
|
||||
|
||||
def initialize(portal_name = "", statement_name = "", parameter_format_codes = [], parameters = [], result_column_format_codes = [])
|
||||
@portal_name = portal_name
|
||||
@statement_name = statement_name
|
||||
@parameter_format_codes = parameter_format_codes
|
||||
@parameters = parameters
|
||||
@result_column_format_codes = result_column_format_codes
|
||||
end
|
||||
|
||||
def to_bytes
|
||||
portal_name_bytes = encode_string(@portal_name)
|
||||
statement_name_bytes = encode_string(@statement_name)
|
||||
parameter_format_codes_bytes = @parameter_format_codes.pack('n*')
|
||||
|
||||
parameters_bytes = @parameters.map do |param|
|
||||
if param.nil?
|
||||
encode_int32(-1)
|
||||
else
|
||||
encode_int32(param.bytesize) + param
|
||||
end
|
||||
end.join
|
||||
|
||||
result_column_format_codes_bytes = @result_column_format_codes.pack('n*')
|
||||
|
||||
length = 4 + portal_name_bytes.size + statement_name_bytes.size + 2 + parameter_format_codes_bytes.size + 2 + parameters_bytes.size + 2 + result_column_format_codes_bytes.size
|
||||
message_prefix('B', length) + portal_name_bytes + statement_name_bytes + encode_int16(@parameter_format_codes.size) + parameter_format_codes_bytes + encode_int16(@parameters.size) + parameters_bytes + encode_int16(@result_column_format_codes.size) + result_column_format_codes_bytes
|
||||
end
|
||||
end
|
||||
|
||||
class DescribeMessage < PostgresMessage
|
||||
attr_accessor :type, :name
|
||||
|
||||
def initialize(type = 'S', name = "")
|
||||
@type = type
|
||||
@name = name
|
||||
end
|
||||
|
||||
def to_bytes
|
||||
name_bytes = encode_string(@name)
|
||||
length = 4 + 1 + name_bytes.size
|
||||
message_prefix('D', length) + @type + name_bytes
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
class ExecuteMessage < PostgresMessage
|
||||
attr_accessor :portal_name, :max_rows
|
||||
|
||||
def initialize(portal_name = "", max_rows = 0)
|
||||
@portal_name = portal_name
|
||||
@max_rows = max_rows
|
||||
end
|
||||
|
||||
def to_bytes
|
||||
portal_name_bytes = encode_string(@portal_name)
|
||||
length = 4 + portal_name_bytes.size + 4
|
||||
message_prefix('E', length) + portal_name_bytes + encode_int32(@max_rows)
|
||||
end
|
||||
end
|
||||
|
||||
class FlushMessage < PostgresMessage
|
||||
def to_bytes
|
||||
length = 4
|
||||
message_prefix('H', length)
|
||||
end
|
||||
end
|
||||
|
||||
class SyncMessage < PostgresMessage
|
||||
def to_bytes
|
||||
length = 4
|
||||
message_prefix('S', length)
|
||||
end
|
||||
end
|
||||
|
||||
class CloseMessage < PostgresMessage
|
||||
attr_accessor :type, :name
|
||||
|
||||
def initialize(type = 'S', name = "")
|
||||
@type = type
|
||||
@name = name
|
||||
end
|
||||
|
||||
def to_bytes
|
||||
name_bytes = encode_string(@name)
|
||||
length = 4 + 1 + name_bytes.size
|
||||
message_prefix('C', length) + @type + name_bytes
|
||||
end
|
||||
end
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
require 'socket'
|
||||
require 'digest/md5'
|
||||
require_relative 'frontend_messages'
|
||||
|
||||
BACKEND_MESSAGE_CODES = {
|
||||
'Z' => "ReadyForQuery",
|
||||
@@ -19,13 +18,9 @@ class PostgresSocket
|
||||
@host = host
|
||||
@socket = TCPSocket.new @host, @port
|
||||
@parameters = {}
|
||||
@verbose = false
|
||||
@verbose = true
|
||||
end
|
||||
|
||||
def send_message(message)
|
||||
@socket.write(message.to_bytes)
|
||||
end
|
||||
|
||||
def send_md5_password_message(username, password, salt)
|
||||
m = Digest::MD5.hexdigest(password + username)
|
||||
m = Digest::MD5.hexdigest(m + salt.map(&:chr).join(""))
|
||||
@@ -118,6 +113,107 @@ class PostgresSocket
|
||||
log "[F] Sent CancelRequest message"
|
||||
end
|
||||
|
||||
def send_query_message(query)
|
||||
query_size = query.length
|
||||
message_size = 1 + 4 + query_size
|
||||
message = []
|
||||
message << "Q".ord
|
||||
message << [message_size].pack('l>').unpack('CCCC') # 4
|
||||
message << query.split('').map(&:ord) # 2, 11
|
||||
message << 0 # 1, 12
|
||||
message.flatten!
|
||||
@socket.write(message.flatten.pack('C*'))
|
||||
log "[F] Sent Q message (#{query})"
|
||||
end
|
||||
|
||||
def send_parse_message(query)
|
||||
query_size = query.length
|
||||
message_size = 2 + 2 + 4 + query_size
|
||||
message = []
|
||||
message << "P".ord
|
||||
message << [message_size].pack('l>').unpack('CCCC') # 4
|
||||
message << 0 # unnamed statement
|
||||
message << query.split('').map(&:ord) # 2, 11
|
||||
message << 0 # 1, 12
|
||||
message << [0, 0]
|
||||
message.flatten!
|
||||
@socket.write(message.flatten.pack('C*'))
|
||||
log "[F] Sent P message (#{query})"
|
||||
end
|
||||
|
||||
def send_bind_message
|
||||
message = []
|
||||
message << "B".ord
|
||||
message << [12].pack('l>').unpack('CCCC') # 4
|
||||
message << 0 # unnamed statement
|
||||
message << 0 # unnamed statement
|
||||
message << [0, 0] # 2
|
||||
message << [0, 0] # 2
|
||||
message << [0, 0] # 2
|
||||
message.flatten!
|
||||
@socket.write(message.flatten.pack('C*'))
|
||||
log "[F] Sent B message"
|
||||
end
|
||||
|
||||
def send_describe_message(mode)
|
||||
message = []
|
||||
message << "D".ord
|
||||
message << [6].pack('l>').unpack('CCCC') # 4
|
||||
message << mode.ord
|
||||
message << 0 # unnamed statement
|
||||
message.flatten!
|
||||
@socket.write(message.flatten.pack('C*'))
|
||||
log "[F] Sent D message"
|
||||
end
|
||||
|
||||
def send_execute_message(limit=0)
|
||||
message = []
|
||||
message << "E".ord
|
||||
message << [9].pack('l>').unpack('CCCC') # 4
|
||||
message << 0 # unnamed statement
|
||||
message << [limit].pack('l>').unpack('CCCC') # 4
|
||||
message.flatten!
|
||||
@socket.write(message.flatten.pack('C*'))
|
||||
log "[F] Sent E message"
|
||||
end
|
||||
|
||||
def send_sync_message
|
||||
message = []
|
||||
message << "S".ord
|
||||
message << [4].pack('l>').unpack('CCCC') # 4
|
||||
message.flatten!
|
||||
@socket.write(message.flatten.pack('C*'))
|
||||
log "[F] Sent S message"
|
||||
end
|
||||
|
||||
def send_copydone_message
|
||||
message = []
|
||||
message << "c".ord
|
||||
message << [4].pack('l>').unpack('CCCC') # 4
|
||||
message.flatten!
|
||||
@socket.write(message.flatten.pack('C*'))
|
||||
log "[F] Sent c message"
|
||||
end
|
||||
|
||||
def send_copyfail_message
|
||||
message = []
|
||||
message << "f".ord
|
||||
message << [5].pack('l>').unpack('CCCC') # 4
|
||||
message << 0
|
||||
message.flatten!
|
||||
@socket.write(message.flatten.pack('C*'))
|
||||
log "[F] Sent f message"
|
||||
end
|
||||
|
||||
def send_flush_message
|
||||
message = []
|
||||
message << "H".ord
|
||||
message << [4].pack('l>').unpack('CCCC') # 4
|
||||
message.flatten!
|
||||
@socket.write(message.flatten.pack('C*'))
|
||||
log "[F] Sent H message"
|
||||
end
|
||||
|
||||
def read_from_server()
|
||||
output_messages = []
|
||||
retry_count = 0
|
||||
|
||||
@@ -16,14 +16,10 @@ describe "Portocol handling" do
|
||||
end
|
||||
|
||||
def run_comparison(sequence, socket_a, socket_b)
|
||||
sequence.each do |msg|
|
||||
if msg.is_a?(Symbol)
|
||||
socket_a.send(msg)
|
||||
socket_b.send(msg)
|
||||
else
|
||||
socket_a.send_message(msg)
|
||||
socket_b.send_message(msg)
|
||||
end
|
||||
sequence.each do |msg, *args|
|
||||
socket_a.send(msg, *args)
|
||||
socket_b.send(msg, *args)
|
||||
|
||||
compare_messages(
|
||||
socket_a.read_from_server,
|
||||
socket_b.read_from_server
|
||||
@@ -87,9 +83,9 @@ describe "Portocol handling" do
|
||||
|
||||
context "Cancel Query" do
|
||||
let(:sequence) {
|
||||
[
|
||||
SimpleQueryMessage.new("SELECT pg_sleep(5)"),
|
||||
:cancel_query
|
||||
[
|
||||
[:send_query_message, "SELECT pg_sleep(5)"],
|
||||
[:cancel_query]
|
||||
]
|
||||
}
|
||||
|
||||
@@ -99,12 +95,12 @@ describe "Portocol handling" do
|
||||
xcontext "Simple query after parse" do
|
||||
let(:sequence) {
|
||||
[
|
||||
ParseMessage.new("", "SELECT 5", []),
|
||||
SimpleQueryMessage.new("SELECT 1"),
|
||||
BindMessage.new("", "", [], [], [0]),
|
||||
DescribeMessage.new("P", ""),
|
||||
ExecuteMessage.new("", 1),
|
||||
SyncMessage.new
|
||||
[:send_parse_message, "SELECT 5"],
|
||||
[:send_query_message, "SELECT 1"],
|
||||
[:send_bind_message],
|
||||
[:send_describe_message, "P"],
|
||||
[:send_execute_message],
|
||||
[:send_sync_message],
|
||||
]
|
||||
}
|
||||
|
||||
@@ -115,8 +111,8 @@ describe "Portocol handling" do
|
||||
xcontext "Flush message" do
|
||||
let(:sequence) {
|
||||
[
|
||||
ParseMessage.new("", "SELECT 1", []),
|
||||
FlushMessage.new
|
||||
[:send_parse_message, "SELECT 1"],
|
||||
[:send_flush_message]
|
||||
]
|
||||
}
|
||||
|
||||
@@ -126,7 +122,9 @@ describe "Portocol handling" do
|
||||
|
||||
xcontext "Bind without parse" do
|
||||
let(:sequence) {
|
||||
[BindMessage.new("", "", [], [], [0])]
|
||||
[
|
||||
[:send_bind_message]
|
||||
]
|
||||
}
|
||||
# This is known to fail.
|
||||
# Server responds immediately, Proxy buffers the message
|
||||
@@ -135,155 +133,23 @@ describe "Portocol handling" do
|
||||
|
||||
context "Simple message" do
|
||||
let(:sequence) {
|
||||
[SimpleQueryMessage.new("SELECT 1")]
|
||||
[[:send_query_message, "SELECT 1"]]
|
||||
}
|
||||
|
||||
it_behaves_like "at parity with database"
|
||||
end
|
||||
|
||||
10.times do |i|
|
||||
context "Extended protocol" do
|
||||
let(:sequence) {
|
||||
[
|
||||
ParseMessage.new("", "SELECT 1", []),
|
||||
BindMessage.new("", "", [], [], [0]),
|
||||
DescribeMessage.new("S", ""),
|
||||
ExecuteMessage.new("", 1),
|
||||
SyncMessage.new
|
||||
]
|
||||
}
|
||||
|
||||
it_behaves_like "at parity with database"
|
||||
end
|
||||
context "Extended protocol" do
|
||||
let(:sequence) {
|
||||
[
|
||||
[:send_parse_message, "SELECT 1"],
|
||||
[:send_bind_message],
|
||||
[:send_describe_message, "P"],
|
||||
[:send_execute_message],
|
||||
[:send_sync_message],
|
||||
]
|
||||
}
|
||||
|
||||
it_behaves_like "at parity with database"
|
||||
end
|
||||
|
||||
describe "Protocol-level prepared statements" do
|
||||
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 1, "transaction") }
|
||||
before do
|
||||
q_sock = PostgresSocket.new('localhost', processes.pgcat.port)
|
||||
q_sock.send_startup_message("sharding_user", "sharded_db", "sharding_user")
|
||||
table_query = "CREATE TABLE IF NOT EXISTS employees (employee_id SERIAL PRIMARY KEY, salary NUMERIC(10, 2) CHECK (salary > 0));"
|
||||
q_sock.send_message(SimpleQueryMessage.new(table_query))
|
||||
q_sock.close
|
||||
|
||||
current_configs = processes.pgcat.current_config
|
||||
current_configs["pools"]["sharded_db"]["prepared_statements_cache_size"] = 500
|
||||
processes.pgcat.update_config(current_configs)
|
||||
processes.pgcat.reload_config
|
||||
end
|
||||
after do
|
||||
q_sock = PostgresSocket.new('localhost', processes.pgcat.port)
|
||||
q_sock.send_startup_message("sharding_user", "sharded_db", "sharding_user")
|
||||
table_query = "DROP TABLE IF EXISTS employees;"
|
||||
q_sock.send_message(SimpleQueryMessage.new(table_query))
|
||||
q_sock.close
|
||||
end
|
||||
|
||||
context "When unnamed prepared statements are used" do
|
||||
it "does not cache them" do
|
||||
socket = PostgresSocket.new('localhost', processes.pgcat.port)
|
||||
socket.send_startup_message("sharding_user", "sharded_db", "sharding_user")
|
||||
|
||||
socket.send_message(SimpleQueryMessage.new("DISCARD ALL"))
|
||||
socket.read_from_server
|
||||
|
||||
10.times do |i|
|
||||
socket.send_message(ParseMessage.new("", "SELECT #{i}", []))
|
||||
socket.send_message(BindMessage.new("", "", [], [], [0]))
|
||||
socket.send_message(DescribeMessage.new("S", ""))
|
||||
socket.send_message(ExecuteMessage.new("", 1))
|
||||
socket.send_message(SyncMessage.new)
|
||||
socket.read_from_server
|
||||
end
|
||||
|
||||
socket.send_message(SimpleQueryMessage.new("SELECT name, statement, prepare_time, parameter_types FROM pg_prepared_statements"))
|
||||
result = socket.read_from_server
|
||||
number_of_saved_statements = result.count { |m| m[:code] == 'D' }
|
||||
expect(number_of_saved_statements).to eq(0)
|
||||
end
|
||||
end
|
||||
|
||||
context "When named prepared statements are used" do
|
||||
it "caches them" do
|
||||
socket = PostgresSocket.new('localhost', processes.pgcat.port)
|
||||
socket.send_startup_message("sharding_user", "sharded_db", "sharding_user")
|
||||
|
||||
socket.send_message(SimpleQueryMessage.new("DISCARD ALL"))
|
||||
socket.read_from_server
|
||||
|
||||
3.times do
|
||||
socket.send_message(ParseMessage.new("my_query", "SELECT * FROM employees WHERE employee_id in ($1,$2,$3)", [0,0,0]))
|
||||
socket.send_message(BindMessage.new("", "my_query", [0,0,0], [0,0,0].map(&:to_s), [0,0,0,0,0,0]))
|
||||
socket.send_message(SyncMessage.new)
|
||||
socket.read_from_server
|
||||
end
|
||||
|
||||
3.times do
|
||||
socket.send_message(ParseMessage.new("my_other_query", "SELECT * FROM employees WHERE salary in ($1,$2,$3)", [0,0,0]))
|
||||
socket.send_message(BindMessage.new("", "my_other_query", [0,0,0], [0,0,0].map(&:to_s), [0,0,0,0,0,0]))
|
||||
socket.send_message(SyncMessage.new)
|
||||
socket.read_from_server
|
||||
end
|
||||
|
||||
socket.send_message(SimpleQueryMessage.new("SELECT name, statement, prepare_time, parameter_types FROM pg_prepared_statements"))
|
||||
result = socket.read_from_server
|
||||
number_of_saved_statements = result.count { |m| m[:code] == 'D' }
|
||||
expect(number_of_saved_statements).to eq(2)
|
||||
end
|
||||
end
|
||||
|
||||
context "When DISCARD ALL/DEALLOCATE ALL are called" do
|
||||
it "resets server and client caches" do
|
||||
socket = PostgresSocket.new('localhost', processes.pgcat.port)
|
||||
socket.send_startup_message("sharding_user", "sharded_db", "sharding_user")
|
||||
|
||||
20.times do |i|
|
||||
socket.send_message(ParseMessage.new("my_query_#{i}", "SELECT * FROM employees WHERE employee_id in ($1,$2,$3)", [0,0,0]))
|
||||
end
|
||||
|
||||
20.times do |i|
|
||||
socket.send_message(BindMessage.new("", "my_query_#{i}", [0,0,0], [0,0,0].map(&:to_s), [0,0]))
|
||||
end
|
||||
|
||||
socket.send_message(SyncMessage.new)
|
||||
socket.read_from_server
|
||||
|
||||
socket.send_message(SimpleQueryMessage.new("DISCARD ALL"))
|
||||
socket.read_from_server
|
||||
responses = []
|
||||
4.times do |i|
|
||||
socket.send_message(ParseMessage.new("my_query_#{i}", "SELECT * FROM employees WHERE employee_id in ($1,$2,$3)", [0,0,0]))
|
||||
socket.send_message(BindMessage.new("", "my_query_#{i}", [0,0,0], [0,0,0].map(&:to_s), [0,0]))
|
||||
socket.send_message(SyncMessage.new)
|
||||
|
||||
responses += socket.read_from_server
|
||||
end
|
||||
|
||||
errors = responses.select { |message| message[:code] == 'E' }
|
||||
error_message = errors.map { |message| message[:bytes].map(&:chr).join("") }.join("\n")
|
||||
raise StandardError, "Encountered the following errors: #{error_message}" if errors.length > 0
|
||||
end
|
||||
end
|
||||
|
||||
context "Maximum number of bound paramters" do
|
||||
it "does not crash" do
|
||||
test_socket = PostgresSocket.new('localhost', processes.pgcat.port)
|
||||
test_socket.send_startup_message("sharding_user", "sharded_db", "sharding_user")
|
||||
|
||||
types = Array.new(65_535) { |i| 0 }
|
||||
|
||||
params = Array.new(65_535) { |i| "$#{i+1}" }.join(",")
|
||||
test_socket.send_message(ParseMessage.new("my_query", "SELECT * FROM employees WHERE employee_id in (#{params})", types))
|
||||
|
||||
test_socket.send_message(BindMessage.new("my_query", "my_query", types, types.map(&:to_s), types))
|
||||
|
||||
test_socket.send_message(SyncMessage.new)
|
||||
|
||||
# If the proxy crashes, this will raise an error
|
||||
expect { test_socket.read_from_server }.to_not raise_error
|
||||
|
||||
test_socket.close
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
Reference in New Issue
Block a user