Compare commits

...

5 Commits

Author SHA1 Message Date
Mostafa
3fc9e5dec1 Merge branch 'main' of github.com:postgresml/pgcat into mostafa_fix_prepared_stmts 2024-09-03 18:11:32 -05:00
Mostafa
f7c5c0faf9 fix bind 2024-09-01 16:14:44 -05:00
Mostafa
982d03c374 fix syntax 2024-09-01 15:41:33 -05:00
Mostafa
686b7ca7c5 Fixes 2024-09-01 15:31:27 -05:00
Mostafa
7c55bf78fe Add failing tests 2024-09-01 14:39:05 -05:00
6 changed files with 361 additions and 171 deletions

View File

@@ -1729,14 +1729,13 @@ where
/// and also the pool's statement cache. Add it to extended protocol data. /// and also the pool's statement cache. Add it to extended protocol data.
fn buffer_parse(&mut self, message: BytesMut, pool: &ConnectionPool) -> Result<(), Error> { fn buffer_parse(&mut self, message: BytesMut, pool: &ConnectionPool) -> Result<(), Error> {
// Avoid parsing if prepared statements not enabled // Avoid parsing if prepared statements not enabled
if !self.prepared_statements_enabled { let client_given_name = Parse::get_name(&message)?;
if !self.prepared_statements_enabled || client_given_name.is_empty() {
debug!("Anonymous parse message"); debug!("Anonymous parse message");
self.extended_protocol_data_buffer self.extended_protocol_data_buffer
.push_back(ExtendedProtocolData::create_new_parse(message, None)); .push_back(ExtendedProtocolData::create_new_parse(message, None));
return Ok(()); return Ok(());
} }
let client_given_name = Parse::get_name(&message)?;
let parse: Parse = (&message).try_into()?; let parse: Parse = (&message).try_into()?;
// Compute the hash of the parse statement // Compute the hash of the parse statement
@@ -1774,15 +1773,14 @@ where
/// saved in the client cache. /// saved in the client cache.
async fn buffer_bind(&mut self, message: BytesMut) -> Result<(), Error> { async fn buffer_bind(&mut self, message: BytesMut) -> Result<(), Error> {
// Avoid parsing if prepared statements not enabled // Avoid parsing if prepared statements not enabled
if !self.prepared_statements_enabled { let client_given_name = Bind::get_name(&message)?;
if !self.prepared_statements_enabled || client_given_name.is_empty() {
debug!("Anonymous bind message"); debug!("Anonymous bind message");
self.extended_protocol_data_buffer self.extended_protocol_data_buffer
.push_back(ExtendedProtocolData::create_new_bind(message, None)); .push_back(ExtendedProtocolData::create_new_bind(message, None));
return Ok(()); return Ok(());
} }
let client_given_name = Bind::get_name(&message)?;
match self.prepared_statements.get(&client_given_name) { match self.prepared_statements.get(&client_given_name) {
Some((rewritten_parse, _)) => { Some((rewritten_parse, _)) => {
let message = Bind::rename(message, &rewritten_parse.name)?; let message = Bind::rename(message, &rewritten_parse.name)?;
@@ -1834,7 +1832,8 @@ where
} }
let describe: Describe = (&message).try_into()?; let describe: Describe = (&message).try_into()?;
if describe.target == 'P' { let client_given_name = describe.statement_name.clone();
if describe.target == 'P' || client_given_name.is_empty() {
debug!("Portal describe message"); debug!("Portal describe message");
self.extended_protocol_data_buffer self.extended_protocol_data_buffer
.push_back(ExtendedProtocolData::create_new_describe(message, None)); .push_back(ExtendedProtocolData::create_new_describe(message, None));
@@ -1842,8 +1841,6 @@ where
return Ok(()); return Ok(());
} }
let client_given_name = describe.statement_name.clone();
match self.prepared_statements.get(&client_given_name) { match self.prepared_statements.get(&client_given_name) {
Some((rewritten_parse, _)) => { Some((rewritten_parse, _)) => {
let describe = describe.rename(&rewritten_parse.name); let describe = describe.rename(&rewritten_parse.name);

View File

@@ -821,10 +821,10 @@ impl ExtendedProtocolData {
pub struct Parse { pub struct Parse {
code: char, code: char,
#[allow(dead_code)] #[allow(dead_code)]
len: i32, len: u32,
pub name: String, pub name: String,
query: String, query: String,
num_params: i16, num_params: u16,
param_types: Vec<i32>, param_types: Vec<i32>,
} }
@@ -834,12 +834,11 @@ impl TryFrom<&BytesMut> for Parse {
fn try_from(buf: &BytesMut) -> Result<Parse, Error> { fn try_from(buf: &BytesMut) -> Result<Parse, Error> {
let mut cursor = Cursor::new(buf); let mut cursor = Cursor::new(buf);
let code = cursor.get_u8() as char; let code = cursor.get_u8() as char;
let len = cursor.get_i32(); let len = cursor.get_u32();
let name = cursor.read_string()?; let name = cursor.read_string()?;
let query = cursor.read_string()?; let query = cursor.read_string()?;
let num_params = cursor.get_i16(); let num_params = cursor.get_u16();
let mut param_types = Vec::new(); let mut param_types = Vec::new();
for _ in 0..num_params { for _ in 0..num_params {
param_types.push(cursor.get_i32()); param_types.push(cursor.get_i32());
} }
@@ -875,10 +874,10 @@ impl TryFrom<Parse> for BytesMut {
+ 4 * parse.num_params as usize; + 4 * parse.num_params as usize;
bytes.put_u8(parse.code as u8); bytes.put_u8(parse.code as u8);
bytes.put_i32(len as i32); bytes.put_u32(len as u32);
bytes.put_slice(name); bytes.put_slice(name);
bytes.put_slice(query); bytes.put_slice(query);
bytes.put_i16(parse.num_params); bytes.put_u16(parse.num_params);
for param in parse.param_types { for param in parse.param_types {
bytes.put_i32(param); bytes.put_i32(param);
} }
@@ -945,14 +944,14 @@ impl Parse {
pub struct Bind { pub struct Bind {
code: char, code: char,
#[allow(dead_code)] #[allow(dead_code)]
len: i64, len: u64,
portal: String, portal: String,
pub prepared_statement: String, pub prepared_statement: String,
num_param_format_codes: i16, num_param_format_codes: u16,
param_format_codes: Vec<i16>, param_format_codes: Vec<i16>,
num_param_values: i16, num_param_values: u16,
param_values: Vec<(i32, BytesMut)>, param_values: Vec<(i32, BytesMut)>,
num_result_column_format_codes: i16, num_result_column_format_codes: u16,
result_columns_format_codes: Vec<i16>, result_columns_format_codes: Vec<i16>,
} }
@@ -962,17 +961,17 @@ impl TryFrom<&BytesMut> for Bind {
fn try_from(buf: &BytesMut) -> Result<Bind, Error> { fn try_from(buf: &BytesMut) -> Result<Bind, Error> {
let mut cursor = Cursor::new(buf); let mut cursor = Cursor::new(buf);
let code = cursor.get_u8() as char; let code = cursor.get_u8() as char;
let len = cursor.get_i32(); let len = cursor.get_u32();
let portal = cursor.read_string()?; let portal = cursor.read_string()?;
let prepared_statement = cursor.read_string()?; let prepared_statement = cursor.read_string()?;
let num_param_format_codes = cursor.get_i16(); let num_param_format_codes = cursor.get_u16();
let mut param_format_codes = Vec::new(); let mut param_format_codes = Vec::new();
for _ in 0..num_param_format_codes { for _ in 0..num_param_format_codes {
param_format_codes.push(cursor.get_i16()); param_format_codes.push(cursor.get_i16());
} }
let num_param_values = cursor.get_i16(); let num_param_values = cursor.get_u16();
let mut param_values = Vec::new(); let mut param_values = Vec::new();
for _ in 0..num_param_values { for _ in 0..num_param_values {
@@ -994,7 +993,7 @@ impl TryFrom<&BytesMut> for Bind {
} }
} }
let num_result_column_format_codes = cursor.get_i16(); let num_result_column_format_codes = cursor.get_u16();
let mut result_columns_format_codes = Vec::new(); let mut result_columns_format_codes = Vec::new();
for _ in 0..num_result_column_format_codes { for _ in 0..num_result_column_format_codes {
@@ -1003,7 +1002,7 @@ impl TryFrom<&BytesMut> for Bind {
Ok(Bind { Ok(Bind {
code, code,
len: len as i64, len: len as u64,
portal, portal,
prepared_statement, prepared_statement,
num_param_format_codes, num_param_format_codes,
@@ -1042,19 +1041,19 @@ impl TryFrom<Bind> for BytesMut {
len += 2 * bind.num_result_column_format_codes as usize; len += 2 * bind.num_result_column_format_codes as usize;
bytes.put_u8(bind.code as u8); bytes.put_u8(bind.code as u8);
bytes.put_i32(len as i32); bytes.put_u32(len as u32);
bytes.put_slice(portal); bytes.put_slice(portal);
bytes.put_slice(prepared_statement); bytes.put_slice(prepared_statement);
bytes.put_i16(bind.num_param_format_codes); bytes.put_u16(bind.num_param_format_codes);
for param_format_code in bind.param_format_codes { for param_format_code in bind.param_format_codes {
bytes.put_i16(param_format_code); bytes.put_i16(param_format_code);
} }
bytes.put_i16(bind.num_param_values); bytes.put_u16(bind.num_param_values);
for (param_len, param) in bind.param_values { for (param_len, param) in bind.param_values {
bytes.put_i32(param_len); bytes.put_i32(param_len);
bytes.put_slice(&param); bytes.put_slice(&param);
} }
bytes.put_i16(bind.num_result_column_format_codes); bytes.put_u16(bind.num_result_column_format_codes);
for result_column_format_code in bind.result_columns_format_codes { for result_column_format_code in bind.result_columns_format_codes {
bytes.put_i16(result_column_format_code); bytes.put_i16(result_column_format_code);
} }
@@ -1068,7 +1067,7 @@ impl Bind {
pub fn get_name(buf: &BytesMut) -> Result<String, Error> { pub fn get_name(buf: &BytesMut) -> Result<String, Error> {
let mut cursor = Cursor::new(buf); let mut cursor = Cursor::new(buf);
// Skip the code and length // Skip the code and length
cursor.advance(mem::size_of::<u8>() + mem::size_of::<i32>()); cursor.advance(mem::size_of::<u8>() + mem::size_of::<u32>());
cursor.read_string()?; cursor.read_string()?;
cursor.read_string() cursor.read_string()
} }
@@ -1078,17 +1077,17 @@ impl Bind {
let mut cursor = Cursor::new(&buf); let mut cursor = Cursor::new(&buf);
// Read basic data from the cursor // Read basic data from the cursor
let code = cursor.get_u8(); let code = cursor.get_u8();
let current_len = cursor.get_i32(); let current_len = cursor.get_u32();
let portal = cursor.read_string()?; let portal = cursor.read_string()?;
let prepared_statement = cursor.read_string()?; let prepared_statement = cursor.read_string()?;
// Calculate new length // Calculate new length
let new_len = current_len + new_name.len() as i32 - prepared_statement.len() as i32; let new_len = current_len + new_name.len() as u32 - prepared_statement.len() as u32;
// Begin building the response buffer // Begin building the response buffer
let mut response_buf = BytesMut::with_capacity(new_len as usize + 1); let mut response_buf = BytesMut::with_capacity(new_len as usize + 1);
response_buf.put_u8(code); response_buf.put_u8(code);
response_buf.put_i32(new_len); response_buf.put_u32(new_len);
// Put the portal and new name into the buffer // Put the portal and new name into the buffer
// Note: panic if the provided string contains null byte // Note: panic if the provided string contains null byte
@@ -1112,7 +1111,7 @@ pub struct Describe {
code: char, code: char,
#[allow(dead_code)] #[allow(dead_code)]
len: i32, len: u32,
pub target: char, pub target: char,
pub statement_name: String, pub statement_name: String,
} }
@@ -1123,7 +1122,7 @@ impl TryFrom<&BytesMut> for Describe {
fn try_from(bytes: &BytesMut) -> Result<Describe, Error> { fn try_from(bytes: &BytesMut) -> Result<Describe, Error> {
let mut cursor = Cursor::new(bytes); let mut cursor = Cursor::new(bytes);
let code = cursor.get_u8() as char; let code = cursor.get_u8() as char;
let len = cursor.get_i32(); let len = cursor.get_u32();
let target = cursor.get_u8() as char; let target = cursor.get_u8() as char;
let statement_name = cursor.read_string()?; let statement_name = cursor.read_string()?;
@@ -1146,7 +1145,7 @@ impl TryFrom<Describe> for BytesMut {
let len = 4 + 1 + statement_name.len(); let len = 4 + 1 + statement_name.len();
bytes.put_u8(describe.code as u8); bytes.put_u8(describe.code as u8);
bytes.put_i32(len as i32); bytes.put_u32(len as u32);
bytes.put_u8(describe.target as u8); bytes.put_u8(describe.target as u8);
bytes.put_slice(statement_name); bytes.put_slice(statement_name);

View File

@@ -698,7 +698,6 @@ impl Server {
)) ))
} }
}; };
trace!("Error: {}", error_code); trace!("Error: {}", error_code);
match error_code { match error_code {
@@ -1013,6 +1012,12 @@ impl Server {
// which can leak between clients. This is a best effort to block bad clients // which can leak between clients. This is a best effort to block bad clients
// from poisoning a transaction-mode pool by setting inappropriate session variables // from poisoning a transaction-mode pool by setting inappropriate session variables
match command.as_str() { match command.as_str() {
"DISCARD ALL" => {
self.clear_prepared_statement_cache();
}
"DEALLOCATE ALL" => {
self.clear_prepared_statement_cache();
}
"SET" => { "SET" => {
// We don't detect set statements in transactions // We don't detect set statements in transactions
// No great way to differentiate between set and set local // No great way to differentiate between set and set local
@@ -1132,6 +1137,12 @@ impl Server {
has_it has_it
} }
fn clear_prepared_statement_cache(&mut self) {
if let Some(cache) = &mut self.prepared_statement_cache {
cache.clear();
}
}
fn add_prepared_statement_to_cache(&mut self, name: &str) -> Option<String> { fn add_prepared_statement_to_cache(&mut self, name: &str) -> Option<String> {
let cache = match &mut self.prepared_statement_cache { let cache = match &mut self.prepared_statement_cache {
Some(cache) => cache, Some(cache) => cache,

View File

@@ -0,0 +1,145 @@
class PostgresMessage
# Base class for common functionality
def encode_string(str)
"#{str}\0" # Encode a string with a null terminator
end
def encode_int16(value)
[value].pack('n') # Encode an Int16
end
def encode_int32(value)
[value].pack('N') # Encode an Int32
end
def message_prefix(type, length)
"#{type}#{encode_int32(length)}" # Message type and length prefix
end
end
class SimpleQueryMessage < PostgresMessage
attr_accessor :query
def initialize(query = "")
@query = query
end
def to_bytes
query_bytes = encode_string(@query)
length = 4 + query_bytes.size # Length includes 4 bytes for length itself
message_prefix('Q', length) + query_bytes
end
end
class ParseMessage < PostgresMessage
attr_accessor :statement_name, :query, :parameter_types
def initialize(statement_name = "", query = "", parameter_types = [])
@statement_name = statement_name
@query = query
@parameter_types = parameter_types
end
def to_bytes
statement_name_bytes = encode_string(@statement_name)
query_bytes = encode_string(@query)
parameter_types_bytes = @parameter_types.pack('N*')
length = 4 + statement_name_bytes.size + query_bytes.size + 2 + parameter_types_bytes.size
message_prefix('P', length) + statement_name_bytes + query_bytes + encode_int16(@parameter_types.size) + parameter_types_bytes
end
end
class BindMessage < PostgresMessage
attr_accessor :portal_name, :statement_name, :parameter_format_codes, :parameters, :result_column_format_codes
def initialize(portal_name = "", statement_name = "", parameter_format_codes = [], parameters = [], result_column_format_codes = [])
@portal_name = portal_name
@statement_name = statement_name
@parameter_format_codes = parameter_format_codes
@parameters = parameters
@result_column_format_codes = result_column_format_codes
end
def to_bytes
portal_name_bytes = encode_string(@portal_name)
statement_name_bytes = encode_string(@statement_name)
parameter_format_codes_bytes = @parameter_format_codes.pack('n*')
parameters_bytes = @parameters.map do |param|
if param.nil?
encode_int32(-1)
else
encode_int32(param.bytesize) + param
end
end.join
result_column_format_codes_bytes = @result_column_format_codes.pack('n*')
length = 4 + portal_name_bytes.size + statement_name_bytes.size + 2 + parameter_format_codes_bytes.size + 2 + parameters_bytes.size + 2 + result_column_format_codes_bytes.size
message_prefix('B', length) + portal_name_bytes + statement_name_bytes + encode_int16(@parameter_format_codes.size) + parameter_format_codes_bytes + encode_int16(@parameters.size) + parameters_bytes + encode_int16(@result_column_format_codes.size) + result_column_format_codes_bytes
end
end
class DescribeMessage < PostgresMessage
attr_accessor :type, :name
def initialize(type = 'S', name = "")
@type = type
@name = name
end
def to_bytes
name_bytes = encode_string(@name)
length = 4 + 1 + name_bytes.size
message_prefix('D', length) + @type + name_bytes
end
end
class ExecuteMessage < PostgresMessage
attr_accessor :portal_name, :max_rows
def initialize(portal_name = "", max_rows = 0)
@portal_name = portal_name
@max_rows = max_rows
end
def to_bytes
portal_name_bytes = encode_string(@portal_name)
length = 4 + portal_name_bytes.size + 4
message_prefix('E', length) + portal_name_bytes + encode_int32(@max_rows)
end
end
class FlushMessage < PostgresMessage
def to_bytes
length = 4
message_prefix('H', length)
end
end
class SyncMessage < PostgresMessage
def to_bytes
length = 4
message_prefix('S', length)
end
end
class CloseMessage < PostgresMessage
attr_accessor :type, :name
def initialize(type = 'S', name = "")
@type = type
@name = name
end
def to_bytes
name_bytes = encode_string(@name)
length = 4 + 1 + name_bytes.size
message_prefix('C', length) + @type + name_bytes
end
end

View File

@@ -1,5 +1,6 @@
require 'socket' require 'socket'
require 'digest/md5' require 'digest/md5'
require_relative 'frontend_messages'
BACKEND_MESSAGE_CODES = { BACKEND_MESSAGE_CODES = {
'Z' => "ReadyForQuery", 'Z' => "ReadyForQuery",
@@ -18,9 +19,13 @@ class PostgresSocket
@host = host @host = host
@socket = TCPSocket.new @host, @port @socket = TCPSocket.new @host, @port
@parameters = {} @parameters = {}
@verbose = true @verbose = false
end end
def send_message(message)
@socket.write(message.to_bytes)
end
def send_md5_password_message(username, password, salt) def send_md5_password_message(username, password, salt)
m = Digest::MD5.hexdigest(password + username) m = Digest::MD5.hexdigest(password + username)
m = Digest::MD5.hexdigest(m + salt.map(&:chr).join("")) m = Digest::MD5.hexdigest(m + salt.map(&:chr).join(""))
@@ -113,107 +118,6 @@ class PostgresSocket
log "[F] Sent CancelRequest message" log "[F] Sent CancelRequest message"
end end
def send_query_message(query)
query_size = query.length
message_size = 1 + 4 + query_size
message = []
message << "Q".ord
message << [message_size].pack('l>').unpack('CCCC') # 4
message << query.split('').map(&:ord) # 2, 11
message << 0 # 1, 12
message.flatten!
@socket.write(message.flatten.pack('C*'))
log "[F] Sent Q message (#{query})"
end
def send_parse_message(query)
query_size = query.length
message_size = 2 + 2 + 4 + query_size
message = []
message << "P".ord
message << [message_size].pack('l>').unpack('CCCC') # 4
message << 0 # unnamed statement
message << query.split('').map(&:ord) # 2, 11
message << 0 # 1, 12
message << [0, 0]
message.flatten!
@socket.write(message.flatten.pack('C*'))
log "[F] Sent P message (#{query})"
end
def send_bind_message
message = []
message << "B".ord
message << [12].pack('l>').unpack('CCCC') # 4
message << 0 # unnamed statement
message << 0 # unnamed statement
message << [0, 0] # 2
message << [0, 0] # 2
message << [0, 0] # 2
message.flatten!
@socket.write(message.flatten.pack('C*'))
log "[F] Sent B message"
end
def send_describe_message(mode)
message = []
message << "D".ord
message << [6].pack('l>').unpack('CCCC') # 4
message << mode.ord
message << 0 # unnamed statement
message.flatten!
@socket.write(message.flatten.pack('C*'))
log "[F] Sent D message"
end
def send_execute_message(limit=0)
message = []
message << "E".ord
message << [9].pack('l>').unpack('CCCC') # 4
message << 0 # unnamed statement
message << [limit].pack('l>').unpack('CCCC') # 4
message.flatten!
@socket.write(message.flatten.pack('C*'))
log "[F] Sent E message"
end
def send_sync_message
message = []
message << "S".ord
message << [4].pack('l>').unpack('CCCC') # 4
message.flatten!
@socket.write(message.flatten.pack('C*'))
log "[F] Sent S message"
end
def send_copydone_message
message = []
message << "c".ord
message << [4].pack('l>').unpack('CCCC') # 4
message.flatten!
@socket.write(message.flatten.pack('C*'))
log "[F] Sent c message"
end
def send_copyfail_message
message = []
message << "f".ord
message << [5].pack('l>').unpack('CCCC') # 4
message << 0
message.flatten!
@socket.write(message.flatten.pack('C*'))
log "[F] Sent f message"
end
def send_flush_message
message = []
message << "H".ord
message << [4].pack('l>').unpack('CCCC') # 4
message.flatten!
@socket.write(message.flatten.pack('C*'))
log "[F] Sent H message"
end
def read_from_server() def read_from_server()
output_messages = [] output_messages = []
retry_count = 0 retry_count = 0

View File

@@ -16,10 +16,14 @@ describe "Portocol handling" do
end end
def run_comparison(sequence, socket_a, socket_b) def run_comparison(sequence, socket_a, socket_b)
sequence.each do |msg, *args| sequence.each do |msg|
socket_a.send(msg, *args) if msg.is_a?(Symbol)
socket_b.send(msg, *args) socket_a.send(msg)
socket_b.send(msg)
else
socket_a.send_message(msg)
socket_b.send_message(msg)
end
compare_messages( compare_messages(
socket_a.read_from_server, socket_a.read_from_server,
socket_b.read_from_server socket_b.read_from_server
@@ -83,9 +87,9 @@ describe "Portocol handling" do
context "Cancel Query" do context "Cancel Query" do
let(:sequence) { let(:sequence) {
[ [
[:send_query_message, "SELECT pg_sleep(5)"], SimpleQueryMessage.new("SELECT pg_sleep(5)"),
[:cancel_query] :cancel_query
] ]
} }
@@ -95,12 +99,12 @@ describe "Portocol handling" do
xcontext "Simple query after parse" do xcontext "Simple query after parse" do
let(:sequence) { let(:sequence) {
[ [
[:send_parse_message, "SELECT 5"], ParseMessage.new("", "SELECT 5", []),
[:send_query_message, "SELECT 1"], SimpleQueryMessage.new("SELECT 1"),
[:send_bind_message], BindMessage.new("", "", [], [], [0]),
[:send_describe_message, "P"], DescribeMessage.new("P", ""),
[:send_execute_message], ExecuteMessage.new("", 1),
[:send_sync_message], SyncMessage.new
] ]
} }
@@ -111,8 +115,8 @@ describe "Portocol handling" do
xcontext "Flush message" do xcontext "Flush message" do
let(:sequence) { let(:sequence) {
[ [
[:send_parse_message, "SELECT 1"], ParseMessage.new("", "SELECT 1", []),
[:send_flush_message] FlushMessage.new
] ]
} }
@@ -122,9 +126,7 @@ describe "Portocol handling" do
xcontext "Bind without parse" do xcontext "Bind without parse" do
let(:sequence) { let(:sequence) {
[ [BindMessage.new("", "", [], [], [0])]
[:send_bind_message]
]
} }
# This is known to fail. # This is known to fail.
# Server responds immediately, Proxy buffers the message # Server responds immediately, Proxy buffers the message
@@ -133,23 +135,155 @@ describe "Portocol handling" do
context "Simple message" do context "Simple message" do
let(:sequence) { let(:sequence) {
[[:send_query_message, "SELECT 1"]] [SimpleQueryMessage.new("SELECT 1")]
} }
it_behaves_like "at parity with database" it_behaves_like "at parity with database"
end end
10.times do |i|
context "Extended protocol" do
let(:sequence) {
[
ParseMessage.new("", "SELECT 1", []),
BindMessage.new("", "", [], [], [0]),
DescribeMessage.new("S", ""),
ExecuteMessage.new("", 1),
SyncMessage.new
]
}
context "Extended protocol" do it_behaves_like "at parity with database"
let(:sequence) { end
[
[:send_parse_message, "SELECT 1"],
[:send_bind_message],
[:send_describe_message, "P"],
[:send_execute_message],
[:send_sync_message],
]
}
it_behaves_like "at parity with database"
end end
end
describe "Protocol-level prepared statements" do
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 1, "transaction") }
before do
q_sock = PostgresSocket.new('localhost', processes.pgcat.port)
q_sock.send_startup_message("sharding_user", "sharded_db", "sharding_user")
table_query = "CREATE TABLE IF NOT EXISTS employees (employee_id SERIAL PRIMARY KEY, salary NUMERIC(10, 2) CHECK (salary > 0));"
q_sock.send_message(SimpleQueryMessage.new(table_query))
q_sock.close
current_configs = processes.pgcat.current_config
current_configs["pools"]["sharded_db"]["prepared_statements_cache_size"] = 500
processes.pgcat.update_config(current_configs)
processes.pgcat.reload_config
end
after do
q_sock = PostgresSocket.new('localhost', processes.pgcat.port)
q_sock.send_startup_message("sharding_user", "sharded_db", "sharding_user")
table_query = "DROP TABLE IF EXISTS employees;"
q_sock.send_message(SimpleQueryMessage.new(table_query))
q_sock.close
end
context "When unnamed prepared statements are used" do
it "does not cache them" do
socket = PostgresSocket.new('localhost', processes.pgcat.port)
socket.send_startup_message("sharding_user", "sharded_db", "sharding_user")
socket.send_message(SimpleQueryMessage.new("DISCARD ALL"))
socket.read_from_server
10.times do |i|
socket.send_message(ParseMessage.new("", "SELECT #{i}", []))
socket.send_message(BindMessage.new("", "", [], [], [0]))
socket.send_message(DescribeMessage.new("S", ""))
socket.send_message(ExecuteMessage.new("", 1))
socket.send_message(SyncMessage.new)
socket.read_from_server
end
socket.send_message(SimpleQueryMessage.new("SELECT name, statement, prepare_time, parameter_types FROM pg_prepared_statements"))
result = socket.read_from_server
number_of_saved_statements = result.count { |m| m[:code] == 'D' }
expect(number_of_saved_statements).to eq(0)
end
end
context "When named prepared statements are used" do
it "caches them" do
socket = PostgresSocket.new('localhost', processes.pgcat.port)
socket.send_startup_message("sharding_user", "sharded_db", "sharding_user")
socket.send_message(SimpleQueryMessage.new("DISCARD ALL"))
socket.read_from_server
3.times do
socket.send_message(ParseMessage.new("my_query", "SELECT * FROM employees WHERE employee_id in ($1,$2,$3)", [0,0,0]))
socket.send_message(BindMessage.new("", "my_query", [0,0,0], [0,0,0].map(&:to_s), [0,0,0,0,0,0]))
socket.send_message(SyncMessage.new)
socket.read_from_server
end
3.times do
socket.send_message(ParseMessage.new("my_other_query", "SELECT * FROM employees WHERE salary in ($1,$2,$3)", [0,0,0]))
socket.send_message(BindMessage.new("", "my_other_query", [0,0,0], [0,0,0].map(&:to_s), [0,0,0,0,0,0]))
socket.send_message(SyncMessage.new)
socket.read_from_server
end
socket.send_message(SimpleQueryMessage.new("SELECT name, statement, prepare_time, parameter_types FROM pg_prepared_statements"))
result = socket.read_from_server
number_of_saved_statements = result.count { |m| m[:code] == 'D' }
expect(number_of_saved_statements).to eq(2)
end
end
context "When DISCARD ALL/DEALLOCATE ALL are called" do
it "resets server and client caches" do
socket = PostgresSocket.new('localhost', processes.pgcat.port)
socket.send_startup_message("sharding_user", "sharded_db", "sharding_user")
20.times do |i|
socket.send_message(ParseMessage.new("my_query_#{i}", "SELECT * FROM employees WHERE employee_id in ($1,$2,$3)", [0,0,0]))
end
20.times do |i|
socket.send_message(BindMessage.new("", "my_query_#{i}", [0,0,0], [0,0,0].map(&:to_s), [0,0]))
end
socket.send_message(SyncMessage.new)
socket.read_from_server
socket.send_message(SimpleQueryMessage.new("DISCARD ALL"))
socket.read_from_server
responses = []
4.times do |i|
socket.send_message(ParseMessage.new("my_query_#{i}", "SELECT * FROM employees WHERE employee_id in ($1,$2,$3)", [0,0,0]))
socket.send_message(BindMessage.new("", "my_query_#{i}", [0,0,0], [0,0,0].map(&:to_s), [0,0]))
socket.send_message(SyncMessage.new)
responses += socket.read_from_server
end
errors = responses.select { |message| message[:code] == 'E' }
error_message = errors.map { |message| message[:bytes].map(&:chr).join("") }.join("\n")
raise StandardError, "Encountered the following errors: #{error_message}" if errors.length > 0
end
end
context "Maximum number of bound paramters" do
it "does not crash" do
test_socket = PostgresSocket.new('localhost', processes.pgcat.port)
test_socket.send_startup_message("sharding_user", "sharded_db", "sharding_user")
types = Array.new(65_535) { |i| 0 }
params = Array.new(65_535) { |i| "$#{i+1}" }.join(",")
test_socket.send_message(ParseMessage.new("my_query", "SELECT * FROM employees WHERE employee_id in (#{params})", types))
test_socket.send_message(BindMessage.new("my_query", "my_query", types, types.map(&:to_s), types))
test_socket.send_message(SyncMessage.new)
# If the proxy crashes, this will raise an error
expect { test_socket.read_from_server }.to_not raise_error
test_socket.close
end
end
end
end