From 7c55bf78feaa7686f7a0c2646e2337fcfab4cece Mon Sep 17 00:00:00 2001 From: Mostafa Date: Sun, 1 Sep 2024 14:39:05 -0500 Subject: [PATCH] Add failing tests --- tests/ruby/helpers/frontend_messages.rb | 145 +++++++++++++++++ tests/ruby/helpers/pg_socket.rb | 108 +------------ tests/ruby/protocol_spec.rb | 198 ++++++++++++++++++++---- 3 files changed, 317 insertions(+), 134 deletions(-) create mode 100644 tests/ruby/helpers/frontend_messages.rb diff --git a/tests/ruby/helpers/frontend_messages.rb b/tests/ruby/helpers/frontend_messages.rb new file mode 100644 index 0000000..1245cf9 --- /dev/null +++ b/tests/ruby/helpers/frontend_messages.rb @@ -0,0 +1,145 @@ + +class PostgresMessage + # Base class for common functionality + + def encode_string(str) + "#{str}\0" # Encode a string with a null terminator + end + + def encode_int16(value) + [value].pack('n') # Encode an Int16 + end + + def encode_int32(value) + [value].pack('N') # Encode an Int32 + end + + def message_prefix(type, length) + "#{type}#{encode_int32(length)}" # Message type and length prefix + end +end + +class SimpleQueryMessage < PostgresMessage + attr_accessor :query + + def initialize(query = "") + @query = query + end + + def to_bytes + query_bytes = encode_string(@query) + length = 4 + query_bytes.size # Length includes 4 bytes for length itself + message_prefix('Q', length) + query_bytes + end +end + +class ParseMessage < PostgresMessage + attr_accessor :statement_name, :query, :parameter_types + + def initialize(statement_name = "", query = "", parameter_types = []) + @statement_name = statement_name + @query = query + @parameter_types = parameter_types + end + + def to_bytes + statement_name_bytes = encode_string(@statement_name) + query_bytes = encode_string(@query) + parameter_types_bytes = @parameter_types.pack('N*') + + length = 4 + statement_name_bytes.size + query_bytes.size + 2 + parameter_types_bytes.size + message_prefix('P', length) + statement_name_bytes + query_bytes + encode_int16(@parameter_types.size) + parameter_types_bytes + end +end + +class BindMessage < PostgresMessage + attr_accessor :portal_name, :statement_name, :parameter_format_codes, :parameters, :result_column_format_codes + + def initialize(portal_name = "", statement_name = "", parameter_format_codes = [], parameters = [], result_column_format_codes = []) + @portal_name = portal_name + @statement_name = statement_name + @parameter_format_codes = parameter_format_codes + @parameters = parameters + @result_column_format_codes = result_column_format_codes + end + + def to_bytes + portal_name_bytes = encode_string(@portal_name) + statement_name_bytes = encode_string(@statement_name) + parameter_format_codes_bytes = @parameter_format_codes.pack('n*') + + parameters_bytes = @parameters.map do |param| + if param.nil? + encode_int32(-1) + else + encode_int32(param.bytesize) + param + end + end.join + + result_column_format_codes_bytes = @result_column_format_codes.pack('n*') + + length = 4 + portal_name_bytes.size + statement_name_bytes.size + 2 + parameter_format_codes_bytes.size + 2 + parameters_bytes.size + 2 + result_column_format_codes_bytes.size + message_prefix('B', length) + portal_name_bytes + statement_name_bytes + encode_int16(@parameter_format_codes.size) + parameter_format_codes_bytes + encode_int16(@parameters.size) + parameters_bytes + encode_int16(@result_column_format_codes.size) + result_column_format_codes_bytes + end +end + +class DescribeMessage < PostgresMessage + attr_accessor :type, :name + + def initialize(type = 'S', name = "") + @type = type + @name = name + end + + def to_bytes + name_bytes = encode_string(@name) + length = 4 + 1 + name_bytes.size + message_prefix('D', length) + @type + name_bytes + end +end + + +class ExecuteMessage < PostgresMessage + attr_accessor :portal_name, :max_rows + + def initialize(portal_name = "", max_rows = 0) + @portal_name = portal_name + @max_rows = max_rows + end + + def to_bytes + portal_name_bytes = encode_string(@portal_name) + length = 4 + portal_name_bytes.size + 4 + message_prefix('E', length) + portal_name_bytes + encode_int32(@max_rows) + end +end + +class FlushMessage < PostgresMessage + def to_bytes + length = 4 + message_prefix('H', length) + end +end + +class SyncMessage < PostgresMessage + def to_bytes + length = 4 + message_prefix('S', length) + end +end + +class CloseMessage < PostgresMessage + attr_accessor :type, :name + + def initialize(type = 'S', name = "") + @type = type + @name = name + end + + def to_bytes + name_bytes = encode_string(@name) + length = 4 + 1 + name_bytes.size + message_prefix('C', length) + @type + name_bytes + end +end + diff --git a/tests/ruby/helpers/pg_socket.rb b/tests/ruby/helpers/pg_socket.rb index 4223449..e4fbb30 100644 --- a/tests/ruby/helpers/pg_socket.rb +++ b/tests/ruby/helpers/pg_socket.rb @@ -1,5 +1,6 @@ require 'socket' require 'digest/md5' +require_relative 'frontend_messages' BACKEND_MESSAGE_CODES = { 'Z' => "ReadyForQuery", @@ -18,9 +19,13 @@ class PostgresSocket @host = host @socket = TCPSocket.new @host, @port @parameters = {} - @verbose = true + @verbose = false end + def send_message(message) + @socket.write(message.to_bytes) + end + def send_md5_password_message(username, password, salt) m = Digest::MD5.hexdigest(password + username) m = Digest::MD5.hexdigest(m + salt.map(&:chr).join("")) @@ -113,107 +118,6 @@ class PostgresSocket log "[F] Sent CancelRequest message" end - def send_query_message(query) - query_size = query.length - message_size = 1 + 4 + query_size - message = [] - message << "Q".ord - message << [message_size].pack('l>').unpack('CCCC') # 4 - message << query.split('').map(&:ord) # 2, 11 - message << 0 # 1, 12 - message.flatten! - @socket.write(message.flatten.pack('C*')) - log "[F] Sent Q message (#{query})" - end - - def send_parse_message(query) - query_size = query.length - message_size = 2 + 2 + 4 + query_size - message = [] - message << "P".ord - message << [message_size].pack('l>').unpack('CCCC') # 4 - message << 0 # unnamed statement - message << query.split('').map(&:ord) # 2, 11 - message << 0 # 1, 12 - message << [0, 0] - message.flatten! - @socket.write(message.flatten.pack('C*')) - log "[F] Sent P message (#{query})" - end - - def send_bind_message - message = [] - message << "B".ord - message << [12].pack('l>').unpack('CCCC') # 4 - message << 0 # unnamed statement - message << 0 # unnamed statement - message << [0, 0] # 2 - message << [0, 0] # 2 - message << [0, 0] # 2 - message.flatten! - @socket.write(message.flatten.pack('C*')) - log "[F] Sent B message" - end - - def send_describe_message(mode) - message = [] - message << "D".ord - message << [6].pack('l>').unpack('CCCC') # 4 - message << mode.ord - message << 0 # unnamed statement - message.flatten! - @socket.write(message.flatten.pack('C*')) - log "[F] Sent D message" - end - - def send_execute_message(limit=0) - message = [] - message << "E".ord - message << [9].pack('l>').unpack('CCCC') # 4 - message << 0 # unnamed statement - message << [limit].pack('l>').unpack('CCCC') # 4 - message.flatten! - @socket.write(message.flatten.pack('C*')) - log "[F] Sent E message" - end - - def send_sync_message - message = [] - message << "S".ord - message << [4].pack('l>').unpack('CCCC') # 4 - message.flatten! - @socket.write(message.flatten.pack('C*')) - log "[F] Sent S message" - end - - def send_copydone_message - message = [] - message << "c".ord - message << [4].pack('l>').unpack('CCCC') # 4 - message.flatten! - @socket.write(message.flatten.pack('C*')) - log "[F] Sent c message" - end - - def send_copyfail_message - message = [] - message << "f".ord - message << [5].pack('l>').unpack('CCCC') # 4 - message << 0 - message.flatten! - @socket.write(message.flatten.pack('C*')) - log "[F] Sent f message" - end - - def send_flush_message - message = [] - message << "H".ord - message << [4].pack('l>').unpack('CCCC') # 4 - message.flatten! - @socket.write(message.flatten.pack('C*')) - log "[F] Sent H message" - end - def read_from_server() output_messages = [] retry_count = 0 diff --git a/tests/ruby/protocol_spec.rb b/tests/ruby/protocol_spec.rb index 9737650..daa83ec 100644 --- a/tests/ruby/protocol_spec.rb +++ b/tests/ruby/protocol_spec.rb @@ -16,10 +16,14 @@ describe "Portocol handling" do end def run_comparison(sequence, socket_a, socket_b) - sequence.each do |msg, *args| - socket_a.send(msg, *args) - socket_b.send(msg, *args) - + sequence.each do |msg| + if msg.is_a?(Symbol) + socket_a.send(msg) + socket_b.send(msg) + else + socket_a.send_message(msg) + socket_b.send_message(msg) + end compare_messages( socket_a.read_from_server, socket_b.read_from_server @@ -83,9 +87,9 @@ describe "Portocol handling" do context "Cancel Query" do let(:sequence) { - [ - [:send_query_message, "SELECT pg_sleep(5)"], - [:cancel_query] + [ + SimpleQueryMessage.new("SELECT pg_sleep(5)"), + :cancel_query ] } @@ -95,12 +99,12 @@ describe "Portocol handling" do xcontext "Simple query after parse" do let(:sequence) { [ - [:send_parse_message, "SELECT 5"], - [:send_query_message, "SELECT 1"], - [:send_bind_message], - [:send_describe_message, "P"], - [:send_execute_message], - [:send_sync_message], + ParseMessage.new("", "SELECT 5", []), + SimpleQueryMessage.new("SELECT 1"), + BindMessage.new("", "", [], [], [0]), + DescribeMessage.new("P", ""), + ExecuteMessage.new("", 1), + SyncMessage.new ] } @@ -111,8 +115,8 @@ describe "Portocol handling" do xcontext "Flush message" do let(:sequence) { [ - [:send_parse_message, "SELECT 1"], - [:send_flush_message] + ParseMessage.new("", "SELECT 1", []), + FlushMessage.new ] } @@ -122,9 +126,7 @@ describe "Portocol handling" do xcontext "Bind without parse" do let(:sequence) { - [ - [:send_bind_message] - ] + [BindMessage.new("", "", [], [], [0])] } # This is known to fail. # Server responds immediately, Proxy buffers the message @@ -133,23 +135,155 @@ describe "Portocol handling" do context "Simple message" do let(:sequence) { - [[:send_query_message, "SELECT 1"]] + [SimpleQueryMessage.new("SELECT 1")] } it_behaves_like "at parity with database" end + + 10.times do |i| + context "Extended protocol" do + let(:sequence) { + [ + ParseMessage.new("", "SELECT 1", []), + BindMessage.new("", "", [], [], [0]), + DescribeMessage.new("S", ""), + ExecuteMessage.new("", 1), + SyncMessage.new + ] + } - context "Extended protocol" do - let(:sequence) { - [ - [:send_parse_message, "SELECT 1"], - [:send_bind_message], - [:send_describe_message, "P"], - [:send_execute_message], - [:send_sync_message], - ] - } - - it_behaves_like "at parity with database" + it_behaves_like "at parity with database" + end end -end + + describe "Protocol-level prepared statements" do + let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 1, "transaction") } + before do + q_sock = PostgresSocket.new('localhost', processes.pgcat.port) + q_sock.send_startup_message("sharding_user", "sharded_db", "sharding_user") + table_query = "CREATE TABLE IF NOT EXISTS employees (employee_id SERIAL PRIMARY KEY, salary NUMERIC(10, 2) CHECK (salary > 0));" + q_sock.send_message(SimpleQueryMessage.new(table_query)) + q_sock.close + + current_configs = processes.pgcat.current_config + current_configs["pools"]["sharded_db"]["prepared_statements_cache_size"] = 500 + processes.pgcat.update_config(current_configs) + processes.pgcat.reload_config + end + after do + q_sock = PostgresSocket.new('localhost', processes.pgcat.port) + q_sock.send_startup_message("sharding_user", "sharded_db", "sharding_user") + table_query = "DROP TABLE IF EXISTS employees;" + q_sock.send_message(SimpleQueryMessage.new(table_query)) + q_sock.close + end + + context "When unnamed prepared statements are used" do + it "does not cache them" do + socket = PostgresSocket.new('localhost', processes.pgcat.port) + socket.send_startup_message("sharding_user", "sharded_db", "sharding_user") + + socket.send_message(SimpleQueryMessage.new("DISCARD ALL")) + socket.read_from_server + + 10.times do |i| + socket.send_message(ParseMessage.new("", "SELECT #{i}", [])) + socket.send_message(BindMessage.new("", "", [], [], [0])) + socket.send_message(DescribeMessage.new("S", "")) + socket.send_message(ExecuteMessage.new("", 1)) + socket.send_message(SyncMessage.new) + socket.read_from_server + end + + socket.send_message(SimpleQueryMessage.new("SELECT name, statement, prepare_time, parameter_types FROM pg_prepared_statements")) + result = socket.read_from_server + number_of_saved_statements = result.count { |m| m[:code] == 'D' } + expect(number_of_saved_statements).to eq(0) + end + end + + context "When named prepared statements are used" do + it "caches them" do + socket = PostgresSocket.new('localhost', processes.pgcat.port) + socket.send_startup_message("sharding_user", "sharded_db", "sharding_user") + + socket.send_message(SimpleQueryMessage.new("DISCARD ALL")) + socket.read_from_server + + 3.times do + socket.send_message(ParseMessage.new("my_query", "SELECT * FROM employees WHERE employee_id in ($1,$2,$3)", [0,0,0])) + socket.send_message(BindMessage.new("", "my_query", [0,0,0], [0,0,0].map(&:to_s), [0,0,0,0,0,0])) + socket.send_message(SyncMessage.new) + socket.read_from_server + end + + 3.times do + socket.send_message(ParseMessage.new("my_other_query", "SELECT * FROM employees WHERE salary in ($1,$2,$3)", [0,0,0])) + socket.send_message(BindMessage.new("", "my_other_query", [0,0,0], [0,0,0].map(&:to_s), [0,0,0,0,0,0])) + socket.send_message(SyncMessage.new) + socket.read_from_server + end + + socket.send_message(SimpleQueryMessage.new("SELECT name, statement, prepare_time, parameter_types FROM pg_prepared_statements")) + result = socket.read_from_server + number_of_saved_statements = result.count { |m| m[:code] == 'D' } + expect(number_of_saved_statements).to eq(2) + end + end + + context "When DISCARD ALL/DEALLOCATE ALL are called" do + it "resets server and client caches" do + socket = PostgresSocket.new('localhost', processes.pgcat.port) + socket.send_startup_message("sharding_user", "sharded_db", "sharding_user") + + 20.times do |i| + socket.send_message(ParseMessage.new("my_query_#{i}", "SELECT * FROM employees WHERE employee_id in ($1,$2,$3)", [0,0,0])) + end + + 20.times do |i| + socket.send_message(BindMessage.new("", "my_query_#{i}", [0,0,0], [0,0,0].map(&:to_s), [0,0,0,0,0,0])) + end + + socket.send_message(SyncMessage.new) + socket.read_from_server + + socket.send_message(SimpleQueryMessage.new("DISCARD ALL")) + socket.read_from_server + responses = [] + 4.times do |i| + socket.send_message(ParseMessage.new("my_query_#{i}", "SELECT * FROM employees WHERE employee_id in ($1,$2,$3)", [0,0,0])) + socket.send_message(BindMessage.new("", "my_query_#{i}", [0,0,0], [0,0,0].map(&:to_s), [0,0,0,0,0,0])) + socket.send_message(SyncMessage.new) + + responses += socket.read_from_server + end + + errors = responses.select { |message| message[:code] == 'E' } + error_message = errors.map { |message| message[:bytes].map(&:chr).join("") }.join("\n") + raise StandardError, "Encountered the following errors: #{error_message}" if errors.length > 0 + end + end + + context "Maximum number of bound paramters" do + it "does not crash" do + test_socket = PostgresSocket.new('localhost', processes.pgcat.port) + test_socket.send_startup_message("sharding_user", "sharded_db", "sharding_user") + + types = Array.new(65_535) { |i| 0 } + + params = Array.new(65_535) { |i| "$#{i+1}" }.join(",") + test_socket.send_message(ParseMessage.new("my_query", "SELECT * FROM employees WHERE employee_id in (#{params})", types)) + + test_socket.send_message(BindMessage.new("my_query", "my_query", types, types.map(&:to_s), types)) + + test_socket.send_message(SyncMessage.new) + + # If the proxy crashes, this will raise an error + expect { test_socket.read_from_server }.to_not raise_error + + test_socket.close + end + end + end +end \ No newline at end of file