diff --git a/.dockerignore b/.dockerignore index ff42aaa..1063ad1 100644 --- a/.dockerignore +++ b/.dockerignore @@ -2,3 +2,5 @@ target/ tests/ tracing/ .circleci/ +.git/ +dev/ diff --git a/Dockerfile.ci b/Dockerfile.ci index 42d213a..4503e87 100644 --- a/Dockerfile.ci +++ b/Dockerfile.ci @@ -1,7 +1,7 @@ -FROM cimg/rust:1.62.0 +FROM cimg/rust:1.67.1 RUN sudo apt-get update && \ sudo apt-get install -y \ - psmisc postgresql-contrib-12 postgresql-client-12 libpq-dev \ + psmisc postgresql-contrib-14 postgresql-client-14 libpq-dev \ ruby ruby-dev python3 python3-pip \ lcov llvm-11 iproute2 && \ sudo apt-get upgrade curl && \ diff --git a/dev/script/console b/dev/script/console index f2a12d6..27715d1 100755 --- a/dev/script/console +++ b/dev/script/console @@ -3,4 +3,10 @@ DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" export HOST_UID="$(id -u)" export HOST_GID="$(id -g)" -docker-compose -f "${DIR}/../docker-compose.yaml" run --rm pgcat-shell + +if [[ "${1}" == "down" ]]; then + docker-compose -f "${DIR}/../docker-compose.yaml" down + exit 0 +else + docker-compose -f "${DIR}/../docker-compose.yaml" run --rm pgcat-shell +fi diff --git a/src/client.rs b/src/client.rs index 2e04c2a..bb70a65 100644 --- a/src/client.rs +++ b/src/client.rs @@ -675,14 +675,42 @@ where // allocate a connection, we wouldn't be able to send back an error message // to the client so we buffer them and defer the decision to error out or not // to when we get the S message - 'P' | 'B' | 'D' | 'E' => { + 'D' | 'E' => { self.buffer.put(&message[..]); continue; } + + 'Q' => { + if query_router.query_parser_enabled() { + query_router.infer(&message); + } + } + + 'P' => { + self.buffer.put(&message[..]); + + if query_router.query_parser_enabled() { + query_router.infer(&message); + } + + continue; + } + + 'B' => { + self.buffer.put(&message[..]); + + if query_router.query_parser_enabled() { + query_router.infer_shard_from_bind(&message); + } + + continue; + } + 'X' => { debug!("Client disconnecting"); return Ok(()); } + _ => (), } @@ -711,11 +739,7 @@ where // Handle all custom protocol commands, if any. match query_router.try_execute_command(&message) { // Normal query, not a custom command. - None => { - if query_router.query_parser_enabled() { - query_router.infer(&message); - } - } + None => (), // SET SHARD TO Some((Command::SetShard, _)) => { @@ -727,7 +751,7 @@ where error_response( &mut self.write, &format!( - "shard {} is more than configured {}, staying on shard {}", + "shard {} is more than configured {}, staying on shard {} (shard numbers start at 0)", query_router.shard(), pool.shards(), current_shard, diff --git a/src/query_router.rs b/src/query_router.rs index fbff68e..578c739 100644 --- a/src/query_router.rs +++ b/src/query_router.rs @@ -43,6 +43,20 @@ pub enum Command { ShowPrimaryReads, } +#[derive(PartialEq, Debug)] +pub enum ShardingKey { + Value(i64), + Placeholder(i16), +} + +#[derive(Clone, Debug)] +enum ParameterFormat { + Text, + Binary, + Uniform(Box), + Specified(Vec), +} + /// Quickly test for match when a query is received. static CUSTOM_SQL_REGEX_SET: OnceCell = OnceCell::new(); @@ -65,6 +79,9 @@ pub struct QueryRouter { /// Pool configuration. pool_settings: PoolSettings, + + // Placeholders from prepared statement. + placeholders: Vec, } impl QueryRouter { @@ -103,6 +120,7 @@ impl QueryRouter { query_parser_enabled: None, primary_reads_enabled: None, pool_settings: PoolSettings::default(), + placeholders: Vec::new(), } } @@ -307,10 +325,10 @@ impl QueryRouter { } /// Try to infer which server to connect to based on the contents of the query. - pub fn infer(&mut self, message_buffer: &BytesMut) -> bool { + pub fn infer(&mut self, message: &BytesMut) -> bool { debug!("Inferring role"); - let mut message_cursor = Cursor::new(message_buffer); + let mut message_cursor = Cursor::new(message); let code = message_cursor.get_u8() as char; let _len = message_cursor.get_i32() as usize; @@ -332,8 +350,7 @@ impl QueryRouter { let query = message_cursor.read_string().unwrap(); debug!("Prepared statement: '{}'", query); - - query.replace('$', "") // Remove placeholders turning them into "values" + query } _ => return false, @@ -343,7 +360,7 @@ impl QueryRouter { Ok(ast) => ast, Err(err) => { // SELECT ... FOR UPDATE won't get parsed correctly. - error!("{}: {}", err, query); + debug!("{}: {}", err, query); self.active_role = Some(Role::Primary); return false; } @@ -404,9 +421,147 @@ impl QueryRouter { true } + /// Parse the shard number from the Bind message + /// which contains the arguments for a prepared statement. + /// + /// N.B.: Only supports anonymous prepared statements since we don't + /// keep a cache of them in PgCat. + pub fn infer_shard_from_bind(&mut self, message: &BytesMut) -> bool { + debug!("Parsing bind message"); + + let mut message_cursor = Cursor::new(message); + + let code = message_cursor.get_u8() as char; + let len = message_cursor.get_i32(); + + if code != 'B' { + debug!("Not a bind packet"); + return false; + } + + // Check message length + if message.len() != len as usize + 1 { + debug!( + "Message has wrong length, expected {}, but have {}", + len, + message.len() + ); + return false; + } + + // There are no shard keys in the prepared statement. + if self.placeholders.is_empty() { + debug!("There are no placeholders in the prepared statement that matched the automatic sharding key"); + return false; + } + + let sharder = Sharder::new( + self.pool_settings.shards, + self.pool_settings.sharding_function, + ); + + let mut shards = BTreeSet::new(); + + let _portal = message_cursor.read_string(); + let _name = message_cursor.read_string(); + + let num_params = message_cursor.get_i16(); + let parameter_format = match num_params { + 0 => ParameterFormat::Text, // Text + 1 => { + let param_format = message_cursor.get_i16(); + ParameterFormat::Uniform(match param_format { + 0 => Box::new(ParameterFormat::Text), + 1 => Box::new(ParameterFormat::Binary), + _ => unreachable!(), + }) + } + n => { + let mut v = Vec::with_capacity(n as usize); + for _ in 0..n { + let param_format = message_cursor.get_i16(); + v.push(match param_format { + 0 => ParameterFormat::Text, + 1 => ParameterFormat::Binary, + _ => unreachable!(), + }); + } + ParameterFormat::Specified(v) + } + }; + + let num_parameters = message_cursor.get_i16(); + + for i in 0..num_parameters { + let mut len = message_cursor.get_i32() as usize; + let format = match ¶meter_format { + ParameterFormat::Text => ParameterFormat::Text, + ParameterFormat::Uniform(format) => *format.clone(), + ParameterFormat::Specified(formats) => formats[i as usize].clone(), + _ => unreachable!(), + }; + + debug!("Parameter {} (len: {}): {:?}", i, len, format); + + // Postgres counts placeholders starting at 1 + let placeholder = i + 1; + + if self.placeholders.contains(&placeholder) { + let value = match format { + ParameterFormat::Text => { + let mut value = String::new(); + while len > 0 { + value.push(message_cursor.get_u8() as char); + len -= 1; + } + + match value.parse::() { + Ok(value) => value, + Err(_) => { + debug!("Error parsing bind value: {}", value); + continue; + } + } + } + + ParameterFormat::Binary => match len { + 2 => message_cursor.get_i16() as i64, + 4 => message_cursor.get_i32() as i64, + 8 => message_cursor.get_i64(), + _ => { + error!( + "Got wrong length for integer type parameter in bind: {}", + len + ); + continue; + } + }, + + _ => unreachable!(), + }; + + shards.insert(sharder.shard(value)); + } + } + + self.placeholders.clear(); + self.placeholders.shrink_to_fit(); + + // We only support querying one shard at a time. + // TODO: Support multi-shard queries some day. + if shards.len() == 1 { + debug!("Found one sharding key"); + self.set_shard(*shards.first().unwrap()); + true + } else { + debug!("Found no sharding keys"); + false + } + } + /// A `selection` is the `WHERE` clause. This parses /// the clause and extracts the sharding key, if present. - fn selection_parser(&self, expr: &Expr, table_names: &Vec>) -> Vec { + fn selection_parser(&self, expr: &Expr, table_names: &Vec>) -> Vec { let mut result = Vec::new(); let mut found = false; @@ -487,13 +642,25 @@ impl QueryRouter { Expr::Value(Value::Number(value, ..)) => { if found { match value.parse::() { - Ok(value) => result.push(value), + Ok(value) => result.push(ShardingKey::Value(value)), Err(_) => { debug!("Sharding key was not an integer: {}", value); } }; } } + + Expr::Value(Value::Placeholder(placeholder)) => { + match placeholder.replace("$", "").parse::() { + Ok(placeholder) => result.push(ShardingKey::Placeholder(placeholder)), + Err(_) => { + debug!( + "Prepared statement didn't have integer placeholders: {}", + placeholder + ); + } + } + } _ => (), }; } @@ -504,7 +671,7 @@ impl QueryRouter { } /// Try to figure out which shard the query should go to. - fn infer_shard(&self, query: &sqlparser::ast::Query) -> Option { + fn infer_shard(&mut self, query: &sqlparser::ast::Query) -> Option { let mut shards = BTreeSet::new(); let mut exprs = Vec::new(); @@ -569,6 +736,11 @@ impl QueryRouter { None => (), }; + let sharder = Sharder::new( + self.pool_settings.shards, + self.pool_settings.sharding_function, + ); + // Look for sharding keys in either the join condition // or the selection. for expr in exprs.iter() { @@ -577,14 +749,17 @@ impl QueryRouter { // TODO: Add support for prepared statements here. // This should just give us the position of the value in the `B` message. - let sharder = Sharder::new( - self.pool_settings.shards, - self.pool_settings.sharding_function, - ); - for value in sharding_keys { - let shard = sharder.shard(value); - shards.insert(shard); + match value { + ShardingKey::Value(value) => { + let shard = sharder.shard(value); + shards.insert(shard); + } + + ShardingKey::Placeholder(position) => { + self.placeholders.push(position); + } + }; } } } @@ -634,10 +809,14 @@ impl QueryRouter { /// Should we attempt to parse queries? pub fn query_parser_enabled(&self) -> bool { - match self.query_parser_enabled { + let enabled = match self.query_parser_enabled { None => self.pool_settings.query_parser_enabled, Some(value) => value, - } + }; + + debug!("Query parser enabled: {}", enabled); + + enabled } pub fn primary_reads_enabled(&self) -> bool { @@ -1066,4 +1245,32 @@ mod test { assert!(qr.infer(&simple_query("SELECT * FROM table_y WHERE another_key = 5"))); assert_eq!(qr.shard(), 0); } + + #[test] + fn test_prepared_statements() { + let stmt = "SELECT * FROM data WHERE id = $1"; + + let mut bind = BytesMut::from(&b"B"[..]); + + let mut payload = BytesMut::from(&b"\0\0"[..]); + payload.put_i16(0); + payload.put_i16(1); + payload.put_i32(1); + payload.put(&b"5"[..]); + payload.put_i16(0); + + bind.put_i32(payload.len() as i32 + 4); + bind.put(payload); + + let mut qr = QueryRouter::new(); + qr.pool_settings.automatic_sharding_key = Some("data.id".to_string()); + qr.pool_settings.shards = 3; + + assert!(qr.infer(&simple_query(stmt))); + assert_eq!(qr.placeholders.len(), 1); + + assert!(qr.infer_shard_from_bind(&bind)); + assert_eq!(qr.shard(), 2); + assert!(qr.placeholders.is_empty()); + } } diff --git a/tests/ruby/helpers/pgcat_helper.rb b/tests/ruby/helpers/pgcat_helper.rb index 544c827..c4ebab7 100644 --- a/tests/ruby/helpers/pgcat_helper.rb +++ b/tests/ruby/helpers/pgcat_helper.rb @@ -24,8 +24,9 @@ module Helpers "default_role" => "any", "pool_mode" => pool_mode, "load_balancing_mode" => lb_mode, - "primary_reads_enabled" => false, - "query_parser_enabled" => false, + "primary_reads_enabled" => true, + "query_parser_enabled" => true, + "automatic_sharding_key" => "data.id", "sharding_function" => "pg_bigint_hash", "shards" => { "0" => { "database" => "shard0", "servers" => [["localhost", primary0.port.to_s, "primary"]] }, diff --git a/tests/ruby/sharding_spec.rb b/tests/ruby/sharding_spec.rb new file mode 100644 index 0000000..4c4053d --- /dev/null +++ b/tests/ruby/sharding_spec.rb @@ -0,0 +1,51 @@ +# frozen_string_literal: true +require_relative 'spec_helper' + + +describe "Sharding" do + let(:processes) { Helpers::Pgcat.three_shard_setup("sharded_db", 5) } + + before do + conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user")) + + # Setup the sharding data + 3.times do |i| + conn.exec("SET SHARD TO '#{i}'") + conn.exec("DELETE FROM data WHERE id > 0") + end + + 18.times do |i| + i = i + 1 + conn.exec("SET SHARDING KEY TO '#{i}'") + conn.exec("INSERT INTO data (id, value) VALUES (#{i}, 'value_#{i}')") + end + end + + after do + + processes.all_databases.map(&:reset) + processes.pgcat.shutdown + end + + describe "automatic routing of extended procotol" do + it "can do it" do + conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user")) + conn.exec("SET SERVER ROLE TO 'auto'") + + 18.times do |i| + result = conn.exec_params("SELECT * FROM data WHERE id = $1", [i + 1]) + expect(result.ntuples).to eq(1) + end + end + + it "can do it with multiple parameters" do + conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user")) + conn.exec("SET SERVER ROLE TO 'auto'") + + 18.times do |i| + result = conn.exec_params("SELECT * FROM data WHERE id = $1 AND id = $2", [i + 1, i + 1]) + expect(result.ntuples).to eq(1) + end + end + end +end