diff --git a/src/client.rs b/src/client.rs index 3aac72c..4c556da 100644 --- a/src/client.rs +++ b/src/client.rs @@ -88,6 +88,9 @@ pub struct Client { /// Used to notify clients about an impending shutdown shutdown: Receiver<()>, + + // Sharding key column position + sharding_key_column: Option, } /// Client entrypoint. @@ -505,6 +508,7 @@ where application_name: application_name.to_string(), shutdown, connected_to_server: false, + sharding_key_column: None, }); } @@ -539,6 +543,7 @@ where application_name: String::from("undefined"), shutdown, connected_to_server: false, + sharding_key_column: None, }); } @@ -724,6 +729,13 @@ where show_response(&mut self.write, "primary reads", &value).await?; continue; } + + // COPY .. SHARDING_KEY_COLUMN .. + Some((Command::StartShardedCopy, value)) => { + custom_protocol_response_ok(&mut self.write, "SHARDED_COPY").await?; + self.sharding_key_column = Some(value.parse::().unwrap()); + continue; + } }; debug!("Waiting for connection from pool"); @@ -792,7 +804,7 @@ where // If the client is in session mode, no more custom protocol // commands will be accepted. loop { - let mut message = if message.len() == 0 { + let message = if message.len() == 0 { trace!("Waiting for message inside transaction or in session mode"); match read_message(&mut self.read).await { @@ -811,154 +823,11 @@ where msg }; - // The message will be forwarded to the server intact. We still would like to - // parse it below to figure out what to do with it. - let original = message.clone(); + match self.handle_message(&pool, server, &address, message).await? { + Some(done) => if done { break; }, + None => return Ok(()), + }; - let code = message.get_u8() as char; - let _len = message.get_i32() as usize; - - trace!("Message: {}", code); - - match code { - // ReadyForQuery - 'Q' => { - debug!("Sending query to server"); - - self.send_and_receive_loop(code, original, server, &address, &pool) - .await?; - - if !server.in_transaction() { - // Report transaction executed statistics. - self.stats.transaction(self.process_id, address.id); - - // Release server back to the pool if we are in transaction mode. - // If we are in session mode, we keep the server until the client disconnects. - if self.transaction_mode { - break; - } - } - } - - // Terminate - 'X' => { - server.checkin_cleanup().await?; - self.release(); - - return Ok(()); - } - - // Parse - // The query with placeholders is here, e.g. `SELECT * FROM users WHERE email = $1 AND active = $2`. - 'P' => { - self.buffer.put(&original[..]); - } - - // Bind - // The placeholder's replacements are here, e.g. 'user@email.com' and 'true' - 'B' => { - self.buffer.put(&original[..]); - } - - // Describe - // Command a client can issue to describe a previously prepared named statement. - 'D' => { - self.buffer.put(&original[..]); - } - - // Execute - // Execute a prepared statement prepared in `P` and bound in `B`. - 'E' => { - self.buffer.put(&original[..]); - } - - // Sync - // Frontend (client) is asking for the query result now. - 'S' => { - debug!("Sending query to server"); - - self.buffer.put(&original[..]); - - // Clone after freeze does not allocate - let first_message_code = (*self.buffer.get(0).unwrap_or(&0)) as char; - - // Almost certainly true - if first_message_code == 'P' { - // Message layout - // P followed by 32 int followed by null-terminated statement name - // So message code should be in offset 0 of the buffer, first character - // in prepared statement name would be index 5 - let first_char_in_name = *self.buffer.get(5).unwrap_or(&0); - if first_char_in_name != 0 { - // This is a named prepared statement - // Server connection state will need to be cleared at checkin - server.mark_dirty(); - } - } - - self.send_and_receive_loop( - code, - self.buffer.clone(), - server, - &address, - &pool, - ) - .await?; - - self.buffer.clear(); - - if !server.in_transaction() { - self.stats.transaction(self.process_id, address.id); - - // Release server back to the pool if we are in transaction mode. - // If we are in session mode, we keep the server until the client disconnects. - if self.transaction_mode { - break; - } - } - } - - // CopyData - 'd' => { - // Forward the data to the server, - // don't buffer it since it can be rather large. - self.send_server_message(server, original, &address, &pool) - .await?; - } - - // CopyDone or CopyFail - // Copy is done, successfully or not. - 'c' | 'f' => { - self.send_server_message(server, original, &address, &pool) - .await?; - - let response = self.receive_server_message(server, &address, &pool).await?; - - match write_all_half(&mut self.write, response).await { - Ok(_) => (), - Err(err) => { - server.mark_bad(); - return Err(err); - } - }; - - if !server.in_transaction() { - self.stats.transaction(self.process_id, address.id); - - // Release server back to the pool if we are in transaction mode. - // If we are in session mode, we keep the server until the client disconnects. - if self.transaction_mode { - break; - } - } - } - - // Some unexpected message. We either did not implement the protocol correctly - // or this is not a Postgres client we're talking to. - _ => { - error!("Unexpected code: {}", code); - } - } } // The server is no longer bound to us, we can't cancel it's queries anymore. @@ -972,6 +841,156 @@ where } } + async fn handle_message(&mut self, pool: &ConnectionPool, server: &mut Server, address: &Address, mut message: BytesMut) -> Result, Error> { + let original = message.clone(); + let code = message.get_u8() as char; + let _len = message.get_i32() as usize; + + trace!("Message: {}", code); + + match code { + // ReadyForQuery + 'Q' => { + debug!("Sending query to server"); + + self.send_and_receive_loop(code, original, server, &address, &pool) + .await?; + + if !server.in_transaction() { + // Report transaction executed statistics. + self.stats.transaction(self.process_id, address.id); + + // Release server back to the pool if we are in transaction mode. + // If we are in session mode, we keep the server until the client disconnects. + if self.transaction_mode { + return Ok(Some(true)); + } + } + } + + // Terminate + 'X' => { + server.checkin_cleanup().await?; + self.release(); + + return Ok(None); + } + + // Parse + // The query with placeholders is here, e.g. `SELECT * FROM users WHERE email = $1 AND active = $2`. + 'P' => { + self.buffer.put(&original[..]); + } + + // Bind + // The placeholder's replacements are here, e.g. 'user@email.com' and 'true' + 'B' => { + self.buffer.put(&original[..]); + } + + // Describe + // Command a client can issue to describe a previously prepared named statement. + 'D' => { + self.buffer.put(&original[..]); + } + + // Execute + // Execute a prepared statement prepared in `P` and bound in `B`. + 'E' => { + self.buffer.put(&original[..]); + } + + // Sync + // Frontend (client) is asking for the query result now. + 'S' => { + debug!("Sending query to server"); + + self.buffer.put(&original[..]); + + // Clone after freeze does not allocate + let first_message_code = (*self.buffer.get(0).unwrap_or(&0)) as char; + + // Almost certainly true + if first_message_code == 'P' { + // Message layout + // P followed by 32 int followed by null-terminated statement name + // So message code should be in offset 0 of the buffer, first character + // in prepared statement name would be index 5 + let first_char_in_name = *self.buffer.get(5).unwrap_or(&0); + if first_char_in_name != 0 { + // This is a named prepared statement + // Server connection state will need to be cleared at checkin + server.mark_dirty(); + } + } + + self.send_and_receive_loop( + code, + self.buffer.clone(), + server, + &address, + &pool, + ) + .await?; + + self.buffer.clear(); + + if !server.in_transaction() { + self.stats.transaction(self.process_id, address.id); + + // Release server back to the pool if we are in transaction mode. + // If we are in session mode, we keep the server until the client disconnects. + if self.transaction_mode { + return Ok(Some(true)); + } + } + } + + // CopyData + 'd' => { + // Forward the data to the server, + // don't buffer it since it can be rather large. + self.send_server_message(server, original, &address, &pool) + .await?; + } + + // CopyDone or CopyFail + // Copy is done, successfully or not. + 'c' | 'f' => { + self.send_server_message(server, original, &address, &pool) + .await?; + + let response = self.receive_server_message(server, &address, &pool).await?; + + match write_all_half(&mut self.write, response).await { + Ok(_) => (), + Err(err) => { + server.mark_bad(); + return Err(err); + } + }; + + if !server.in_transaction() { + self.stats.transaction(self.process_id, address.id); + + // Release server back to the pool if we are in transaction mode. + // If we are in session mode, we keep the server until the client disconnects. + if self.transaction_mode { + return Ok(Some(true)); + } + } + } + + // Some unexpected message. We either did not implement the protocol correctly + // or this is not a Postgres client we're talking to. + _ => { + error!("Unexpected code: {}", code); + } + }; + + Ok(Some(false)) + } + /// Release the server from the client: it can't cancel its queries anymore. pub fn release(&self) { let mut guard = self.client_server_map.lock(); diff --git a/src/query_router.rs b/src/query_router.rs index f9d5f0b..9ed4924 100644 --- a/src/query_router.rs +++ b/src/query_router.rs @@ -13,7 +13,7 @@ use crate::pool::PoolSettings; use crate::sharding::Sharder; /// Regexes used to parse custom commands. -const CUSTOM_SQL_REGEXES: [&str; 7] = [ +const CUSTOM_SQL_REGEXES: [&str; 8] = [ r"(?i)^ *SET SHARDING KEY TO '?([0-9]+)'? *;? *$", r"(?i)^ *SET SHARD TO '?([0-9]+|ANY)'? *;? *$", r"(?i)^ *SHOW SHARD *;? *$", @@ -21,6 +21,7 @@ const CUSTOM_SQL_REGEXES: [&str; 7] = [ r"(?i)^ *SHOW SERVER ROLE *;? *$", r"(?i)^ *SET PRIMARY READS TO '?(on|off|default)'? *;? *$", r"(?i)^ *SHOW PRIMARY READS *;? *$", + r"(?i)^ *SHARDED_COPY '?([0-9]+)'? *;? *$", ]; /// Custom commands. @@ -33,6 +34,7 @@ pub enum Command { ShowServerRole, SetPrimaryReads, ShowPrimaryReads, + StartShardedCopy, } /// Quickly test for match when a query is received. @@ -57,6 +59,9 @@ pub struct QueryRouter { /// Pool configuration. pool_settings: PoolSettings, + + // Sharding key column + sharding_key_column: Option, } impl QueryRouter { @@ -98,6 +103,7 @@ impl QueryRouter { query_parser_enabled: false, primary_reads_enabled: false, pool_settings: PoolSettings::default(), + sharding_key_column: None, } } @@ -145,6 +151,7 @@ impl QueryRouter { 4 => Command::ShowServerRole, 5 => Command::SetPrimaryReads, 6 => Command::ShowPrimaryReads, + 7 => Command::StartShardedCopy, _ => unreachable!(), }; @@ -152,7 +159,8 @@ impl QueryRouter { Command::SetShardingKey | Command::SetShard | Command::SetServerRole - | Command::SetPrimaryReads => { + | Command::SetPrimaryReads + | Command::StartShardedCopy => { // Capture value. I know this re-runs the regex engine, but I haven't // figured out a better way just yet. I think I can write a single Regex // that matches all 5 custom SQL patterns, but maybe that's not very legible? @@ -204,6 +212,13 @@ impl QueryRouter { }; } + Command::StartShardedCopy => { + self.sharding_key_column = match value.parse::() { + Ok(value) => Some(value), + Err(_) => return None, + } + } + Command::SetServerRole => { self.active_role = match value.to_ascii_lowercase().as_ref() { "primary" => {