diff --git a/src/client.rs b/src/client.rs index 0f3fd31..dd1b7cc 100644 --- a/src/client.rs +++ b/src/client.rs @@ -2,8 +2,6 @@ /// We are pretending to the server in this scenario, /// and this module implements that. use bytes::{Buf, BufMut, BytesMut}; -use once_cell::sync::OnceCell; -use regex::Regex; use tokio::io::{AsyncReadExt, BufReader}; use tokio::net::{ tcp::{OwnedReadHalf, OwnedWriteHalf}, @@ -17,16 +15,10 @@ use crate::constants::*; use crate::errors::Error; use crate::messages::*; use crate::pool::{ClientServerMap, ConnectionPool}; +use crate::query_router::QueryRouter; use crate::server::Server; -use crate::sharding::Sharder; use crate::stats::Reporter; -pub const SHARDING_REGEX: &str = r"SET SHARDING KEY TO '[0-9]+';"; -pub const ROLE_REGEX: &str = r"SET SERVER ROLE TO '(PRIMARY|REPLICA)';"; - -pub static SHARDING_REGEX_RE: OnceCell = OnceCell::new(); -pub static ROLE_REGEX_RE: OnceCell = OnceCell::new(); - /// The client state. One of these is created per client. pub struct Client { // The reads are buffered (8K by default). @@ -199,15 +191,11 @@ impl Client { return Ok(Server::cancel(&address, &port, process_id, secret_key).await?); } - // Active shard we're talking to. - // The lifetime of this depends on the pool mode: - // - if in session mode, this lives until the client disconnects, - // - if in transaction mode, this lives for the duration of one transaction. - let mut shard: Option = None; - - // Active database role we want to talk to, e.g. primary or replica. - let mut role: Option = self.default_server_role; + let mut query_router = QueryRouter::new(self.default_server_role, pool.shards()); + // Our custom protocol loop. + // We expect the client to either start a transaction with regular queries + // or issue commands for our sharding and server selection protocols. loop { // Read a complete message from the client, which normally would be // either a `Q` (query) or `P` (prepare, extended protocol). @@ -218,32 +206,31 @@ impl Client { // Parse for special select shard command. // SET SHARDING KEY TO 'bigint'; - match self.select_shard(message.clone(), pool.shards()) { - Some(s) => { - custom_protocol_response_ok(&mut self.write, "SET SHARDING KEY").await?; - shard = Some(s); - continue; - } - None => (), - }; + if query_router.select_shard(message.clone()) { + custom_protocol_response_ok( + &mut self.write, + &format!("SET SHARD TO {}", query_router.shard()), + ) + .await?; + continue; + } // Parse for special server role selection command. // SET SERVER ROLE TO '(primary|replica)'; - match self.select_role(message.clone()) { - Some(r) => { - custom_protocol_response_ok(&mut self.write, "SET SERVER ROLE").await?; - role = Some(r); - continue; - } - None => (), - }; + if query_router.select_role(message.clone()) { + custom_protocol_response_ok(&mut self.write, "SET SERVER ROLE").await?; + continue; + } - // Grab a server from the pool. - let connection = match pool.get(shard, role).await { + // Grab a server from the pool: the client issued a regular query. + let connection = match pool.get(query_router.shard(), query_router.role()).await { Ok(conn) => conn, Err(err) => { println!(">> Could not get connection from pool: {:?}", err); - return Err(err); + error_response(&mut self.write, "could not get connection from the pool") + .await?; + query_router.reset(); + continue; } }; @@ -264,11 +251,8 @@ impl Client { Err(err) => { // Client disconnected without warning. if server.in_transaction() { - // TODO: this is what PgBouncer does - // which leads to connection thrashing. - // - // I think we could issue a ROLLBACK here instead. - // server.mark_bad(); + // Client left dirty server. Clean up and proceed + // without thrashing this connection. server.query("ROLLBACK; DISCARD ALL;").await?; } @@ -328,8 +312,7 @@ impl Client { // Report this client as idle. self.stats.client_idle(); - shard = None; - role = self.default_server_role; + query_router.reset(); break; } @@ -414,8 +397,7 @@ impl Client { if self.transaction_mode { self.stats.client_idle(); - shard = None; - role = self.default_server_role; + query_router.reset(); break; } @@ -450,8 +432,7 @@ impl Client { self.stats.transaction(); if self.transaction_mode { - shard = None; - role = self.default_server_role; + query_router.reset(); break; } @@ -476,77 +457,4 @@ impl Client { let mut guard = self.client_server_map.lock().unwrap(); guard.remove(&(self.process_id, self.secret_key)); } - - /// Determine if the query is part of our special syntax, extract - /// the shard key, and return the shard to query based on Postgres' - /// PARTITION BY HASH function. - fn select_shard(&self, mut buf: BytesMut, shards: usize) -> Option { - let code = buf.get_u8() as char; - - // Only supporting simpe protocol here, so - // one would have to execute something like this: - // psql -c "SET SHARDING KEY TO '1234'" - // after sanitizing the value manually, which can be just done with an - // int parser, e.g. `let key = "1234".parse::().unwrap()`. - match code { - 'Q' => (), - _ => return None, - }; - - let len = buf.get_i32(); - let query = String::from_utf8_lossy(&buf[..len as usize - 4 - 1]).to_ascii_uppercase(); // Don't read the ternminating null - - let rgx = match SHARDING_REGEX_RE.get() { - Some(r) => r, - None => return None, - }; - - if rgx.is_match(&query) { - let shard = query.split("'").collect::>()[1]; - - match shard.parse::() { - Ok(shard) => { - let sharder = Sharder::new(shards); - Some(sharder.pg_bigint_hash(shard)) - } - - Err(_) => None, - } - } else { - None - } - } - - // Pick a primary or a replica from the pool. - fn select_role(&self, mut buf: BytesMut) -> Option { - let code = buf.get_u8() as char; - - // Same story as select_shard() above. - match code { - 'Q' => (), - _ => return None, - }; - - let len = buf.get_i32(); - let query = String::from_utf8_lossy(&buf[..len as usize - 4 - 1]).to_ascii_uppercase(); - - let rgx = match ROLE_REGEX_RE.get() { - Some(r) => r, - None => return None, - }; - - // Copy / paste from above. If we get one more of these use cases, - // it'll be time to abstract :). - if rgx.is_match(&query) { - let role = query.split("'").collect::>()[1]; - - match role { - "PRIMARY" => Some(Role::Primary), - "REPLICA" => Some(Role::Replica), - _ => return None, - } - } else { - None - } - } } diff --git a/src/main.rs b/src/main.rs index 324f901..cd74162 100644 --- a/src/main.rs +++ b/src/main.rs @@ -25,7 +25,6 @@ extern crate statsd; extern crate tokio; extern crate toml; -use regex::Regex; use tokio::net::TcpListener; use tokio::signal; @@ -39,6 +38,7 @@ mod constants; mod errors; mod messages; mod pool; +mod query_router; mod server; mod sharding; mod stats; @@ -54,12 +54,11 @@ use stats::{Collector, Reporter}; async fn main() { println!("> Welcome to PgCat! Meow."); - client::SHARDING_REGEX_RE - .set(Regex::new(client::SHARDING_REGEX).unwrap()) - .unwrap(); - client::ROLE_REGEX_RE - .set(Regex::new(client::ROLE_REGEX).unwrap()) - .unwrap(); + // Prepare regexes + if !query_router::QueryRouter::setup() { + println!("> Could not setup query router."); + return; + } let config = match config::parse("pgcat.toml").await { Ok(config) => config, diff --git a/src/messages.rs b/src/messages.rs index ec5ac66..3b5914d 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -185,6 +185,50 @@ pub async fn custom_protocol_response_ok( write_all_half(stream, res).await } +/// Send a custom error message to the client. +/// Tell the client we are ready for the next query and no rollback is necessary. +/// Docs on error codes: https://www.postgresql.org/docs/12/errcodes-appendix.html +pub async fn error_response(stream: &mut OwnedWriteHalf, message: &str) -> Result<(), Error> { + let mut error = BytesMut::new(); + + // Error level + error.put_u8(b'S'); + error.put_slice(&b"FATAL\0"[..]); + + // Error level (non-translatable) + error.put_u8(b'V'); + error.put_slice(&b"FATAL\0"[..]); + + // Error code: not sure how much this matters. + error.put_u8(b'C'); + error.put_slice(&b"58000\0"[..]); // system_error, see Appendix A. + + // The short error message. + error.put_u8(b'M'); + error.put_slice(&format!("{}\0", message).as_bytes()); + + // No more fields follow. + error.put_u8(0); + + // Ready for query, no rollback needed (I = idle). + let mut ready_for_query = BytesMut::new(); + + ready_for_query.put_u8(b'Z'); + ready_for_query.put_i32(5); + ready_for_query.put_u8(b'I'); + + // Compose the two message reply. + let mut res = BytesMut::with_capacity(error.len() + ready_for_query.len() + 5); + + res.put_u8(b'E'); + res.put_i32(error.len() as i32 + 4); + + res.put(error); + res.put(ready_for_query); + + Ok(write_all_half(stream, res).await?) +} + /// Write all data in the buffer to the TcpStream. pub async fn write_all(stream: &mut TcpStream, buf: BytesMut) -> Result<(), Error> { match stream.write_all(&buf).await { diff --git a/src/pool.rs b/src/pool.rs index d348eec..a5e32d1 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -121,7 +121,7 @@ impl ConnectionPool { for shard in 0..self.shards() { for _ in 0..self.replicas(shard) { - let connection = match self.get(Some(shard), None).await { + let connection = match self.get(shard, None).await { Ok(conn) => conn, Err(err) => { println!("> Shard {} down or misconfigured.", shard); @@ -149,18 +149,13 @@ impl ConnectionPool { /// Get a connection from the pool. pub async fn get( &mut self, - shard: Option, + shard: usize, role: Option, ) -> Result<(PooledConnection<'_, ServerPool>, Address), Error> { // Set this to false to gain ~3-4% speed. let with_health_check = true; let now = Instant::now(); - let shard = match shard { - Some(shard) => shard, - None => 0, // TODO: pick a shard at random - }; - // We are waiting for a server now. self.stats.client_waiting(); @@ -208,11 +203,8 @@ impl ConnectionPool { // as per request. match role { Some(role) => { - // If the client wants a specific role, - // we'll do our best to pick it, but if we only - // have one server in the cluster, it's probably only a primary - // (or only a replica), so the client will just get what we have. - if address.role != role && addresses.len() > 1 { + // Find the specific role the client wants in the pool. + if address.role != role { continue; } } diff --git a/src/query_router.rs b/src/query_router.rs new file mode 100644 index 0000000..aa73d56 --- /dev/null +++ b/src/query_router.rs @@ -0,0 +1,259 @@ +use bytes::{Buf, BytesMut}; +/// Route queries automatically based on explicitely requested +/// or implied query characteristics. +use once_cell::sync::OnceCell; +use regex::{Regex, RegexBuilder}; + +use crate::config::Role; +use crate::sharding::Sharder; + +const SHARDING_REGEX: &str = r"SET SHARDING KEY TO '[0-9]+';"; +const ROLE_REGEX: &str = r"SET SERVER ROLE TO '(PRIMARY|REPLICA)';"; + +static SHARDING_REGEX_RE: OnceCell = OnceCell::new(); +static ROLE_REGEX_RE: OnceCell = OnceCell::new(); + +pub struct QueryRouter { + // By default, queries go here, unless we have better information + // about what the client wants. + default_server_role: Option, + + // Number of shards in the cluster. + shards: usize, + + // Which shard we should be talking to right now. + active_shard: Option, + + // Should we be talking to a primary or a replica? + active_role: Option, +} + +impl QueryRouter { + pub fn setup() -> bool { + // Compile our query routing regexes early, so we only do it once. + let a = match SHARDING_REGEX_RE.set( + RegexBuilder::new(SHARDING_REGEX) + .case_insensitive(true) + .build() + .unwrap(), + ) { + Ok(_) => true, + Err(_) => false, + }; + + let b = match ROLE_REGEX_RE.set( + RegexBuilder::new(ROLE_REGEX) + .case_insensitive(true) + .build() + .unwrap(), + ) { + Ok(_) => true, + Err(_) => false, + }; + + a && b + } + + pub fn new(default_server_role: Option, shards: usize) -> QueryRouter { + QueryRouter { + default_server_role: default_server_role, + shards: shards, + + active_role: default_server_role, + active_shard: None, + } + } + + /// Determine if the query is part of our special syntax, extract + /// the shard key, and return the shard to query based on Postgres' + /// PARTITION BY HASH function. + pub fn select_shard(&mut self, mut buf: BytesMut) -> bool { + let code = buf.get_u8() as char; + + // Only supporting simpe protocol here, so + // one would have to execute something like this: + // psql -c "SET SHARDING KEY TO '1234'" + // after sanitizing the value manually, which can be just done with an + // int parser, e.g. `let key = "1234".parse::().unwrap()`. + match code { + 'Q' => (), + _ => return false, + }; + + let len = buf.get_i32(); + let query = String::from_utf8_lossy(&buf[..len as usize - 4 - 1]); // Don't read the ternminating null + + let rgx = match SHARDING_REGEX_RE.get() { + Some(r) => r, + None => return false, + }; + + if rgx.is_match(&query) { + let shard = query.split("'").collect::>()[1]; + + match shard.parse::() { + Ok(shard) => { + let sharder = Sharder::new(self.shards); + self.active_shard = Some(sharder.pg_bigint_hash(shard)); + + true + } + + // The shard must be a valid integer. Our regex won't let anything else pass, + // so this code will never run, but Rust can't know that, so we have to handle this + // case anyway. + Err(_) => false, + } + } else { + false + } + } + + // Pick a primary or a replica from the pool. + pub fn select_role(&mut self, mut buf: BytesMut) -> bool { + let code = buf.get_u8() as char; + + // Same story as select_shard() above. + match code { + 'Q' => (), + _ => return false, + }; + + let len = buf.get_i32(); + let query = String::from_utf8_lossy(&buf[..len as usize - 4 - 1]).to_ascii_uppercase(); + + let rgx = match ROLE_REGEX_RE.get() { + Some(r) => r, + None => return false, + }; + + // Copy / paste from above. If we get one more of these use cases, + // it'll be time to abstract :). + if rgx.is_match(&query) { + let role = query.split("'").collect::>()[1]; + + match role { + "PRIMARY" => { + self.active_role = Some(Role::Primary); + true + } + "REPLICA" => { + self.active_role = Some(Role::Replica); + true + } + + // Our regex won't let this case happen, but Rust can't know that. + _ => false, + } + } else { + false + } + } + + /// Get the current desired server role we should be talking to. + pub fn role(&self) -> Option { + self.active_role + } + + /// Get desired shard we should be talking to. + pub fn shard(&self) -> usize { + match self.active_shard { + Some(shard) => shard, + None => 0, // TODO: pick random shard + } + } + + /// Reset the router back to defaults. + /// This must be called at the end of every transaction in transaction mode. + pub fn reset(&mut self) { + self.active_role = self.default_server_role; + self.active_shard = None; + } +} + +#[cfg(test)] +mod test { + use super::*; + use bytes::BufMut; + + #[test] + fn test_select_shard() { + QueryRouter::setup(); + + let default_server_role: Option = None; + let shards = 5; + let mut query_router = QueryRouter::new(default_server_role, shards); + + // Build the special syntax query. + let mut message = BytesMut::new(); + let query = BytesMut::from(&b"SET SHARDING KEY TO '13';\0"[..]); + + message.put_u8(b'Q'); // Query + message.put_i32(query.len() as i32 + 4); + message.put_slice(&query[..]); + + assert!(query_router.select_shard(message)); + assert_eq!(query_router.shard(), 3); // See sharding.rs (we are using 5 shards on purpose in this test) + + query_router.reset(); + assert_eq!(query_router.shard(), 0); + } + + #[test] + fn test_select_replica() { + QueryRouter::setup(); + + let default_server_role: Option = None; + let shards = 5; + let mut query_router = QueryRouter::new(default_server_role, shards); + + // Build the special syntax query. + let mut message = BytesMut::new(); + let query = BytesMut::from(&b"SET SERVER ROLE TO 'replica';\0"[..]); + + message.put_u8(b'Q'); // Query + message.put_i32(query.len() as i32 + 4); + message.put_slice(&query[..]); + + assert!(query_router.select_role(message)); + assert_eq!(query_router.role(), Some(Role::Replica)); + + query_router.reset(); + + assert_eq!(query_router.role(), default_server_role); + } + + #[test] + fn test_defaults() { + QueryRouter::setup(); + + let default_server_role: Option = None; + let shards = 5; + let query_router = QueryRouter::new(default_server_role, shards); + + assert_eq!(query_router.shard(), 0); + assert_eq!(query_router.role(), None); + } + + #[test] + fn test_incorrect_syntax() { + QueryRouter::setup(); + + let default_server_role: Option = None; + let shards = 5; + let mut query_router = QueryRouter::new(default_server_role, shards); + + // Build the special syntax query. + let mut message = BytesMut::new(); + + // Typo! + let query = BytesMut::from(&b"SET SERVER RLE TO 'replica';\0"[..]); + + message.put_u8(b'Q'); // Query + message.put_i32(query.len() as i32 + 4); + message.put_slice(&query[..]); + + assert_eq!(query_router.select_shard(message.clone()), false); + assert_eq!(query_router.select_role(message.clone()), false); + } +} diff --git a/tests/sharding/partition_hash_test_setup.sql b/tests/sharding/partition_hash_test_setup.sql index e802ead..5e91020 100644 --- a/tests/sharding/partition_hash_test_setup.sql +++ b/tests/sharding/partition_hash_test_setup.sql @@ -23,4 +23,4 @@ SELECT * FROM shard_0 ORDER BY id LIMIT 10; SELECT * FROM shard_1 ORDER BY id LIMIT 10; SELECT * FROM shard_2 ORDER BY id LIMIT 10; SELECT * FROM shard_3 ORDER BY id LIMIT 10; -SELECT * FROM shard_4 ORDER BY id LIMIT 10; \ No newline at end of file +SELECT * FROM shard_4 ORDER BY id LIMIT 10; diff --git a/tests/sharding/query_routing_setup.sql b/tests/sharding/query_routing_setup.sql index d4e766d..8ecfc58 100644 --- a/tests/sharding/query_routing_setup.sql +++ b/tests/sharding/query_routing_setup.sql @@ -58,4 +58,4 @@ GRANT ALL ON TABLE data TO sharding_user; \c shard2 GRANT ALL ON SCHEMA public TO sharding_user; -GRANT ALL ON TABLE data TO sharding_user; \ No newline at end of file +GRANT ALL ON TABLE data TO sharding_user; diff --git a/tests/sharding/query_routing_test_insert.sql b/tests/sharding/query_routing_test_insert.sql index 97d5bab..ff8a326 100644 --- a/tests/sharding/query_routing_test_insert.sql +++ b/tests/sharding/query_routing_test_insert.sql @@ -1,3 +1,5 @@ +\set ON_ERROR_STOP on + SET SHARDING KEY TO '1'; INSERT INTO data (id, value) VALUES (1, 'value_1'); @@ -44,4 +46,10 @@ SET SHARDING KEY TO '15'; INSERT INTO data (id, value) VALUES (15, 'value_1'); SET SHARDING KEY TO '16'; -INSERT INTO data (id, value) VALUES (16, 'value_1'); \ No newline at end of file +INSERT INTO data (id, value) VALUES (16, 'value_1'); + +set sharding key to '17'; +INSERT INTO data (id, value) VALUES (17, 'value_1'); + +SeT SHaRDInG KeY to '18'; +INSERT INTO data (id, value) VALUES (18, 'value_1'); diff --git a/tests/sharding/query_routing_test_primary_replica.sql b/tests/sharding/query_routing_test_primary_replica.sql index 358b073..db05ba7 100644 --- a/tests/sharding/query_routing_test_primary_replica.sql +++ b/tests/sharding/query_routing_test_primary_replica.sql @@ -1,3 +1,5 @@ +\set ON_ERROR_STOP on + SET SERVER ROLE TO 'primary'; SET SHARDING KEY TO '1'; INSERT INTO data (id, value) VALUES (1, 'value_1'); @@ -88,6 +90,8 @@ SELECT * FROM data WHERE id = 9; --- +\set ON_ERROR_STOP on + SET SERVER ROLE TO 'primary'; SET SHARDING KEY TO '10'; INSERT INTO data (id, value) VALUES (10, 'value_1'); @@ -143,3 +147,7 @@ SELECT 1; SET SERVER ROLE TO 'replica'; SELECT 1; + +set server role to 'replica'; +SeT SeRver Role TO 'PrImARY'; +select 1; diff --git a/tests/sharding/query_routing_test_select.sql b/tests/sharding/query_routing_test_select.sql index c577803..1b30fdf 100644 --- a/tests/sharding/query_routing_test_select.sql +++ b/tests/sharding/query_routing_test_select.sql @@ -1,3 +1,5 @@ +\set ON_ERROR_STOP on + SET SHARDING KEY TO '1'; SELECT * FROM data WHERE id = 1; @@ -44,4 +46,4 @@ SET SHARDING KEY TO '15'; SELECT * FROM data WHERE id = 15; SET SHARDING KEY TO '16'; -SELECT * FROM data WHERE id = 16; \ No newline at end of file +SELECT * FROM data WHERE id = 16; diff --git a/tests/sharding/query_routing_test_validate.sql b/tests/sharding/query_routing_test_validate.sql index 5ef9a56..1647087 100644 --- a/tests/sharding/query_routing_test_validate.sql +++ b/tests/sharding/query_routing_test_validate.sql @@ -8,4 +8,4 @@ SELECT * FROM data; \c shard2 -SELECT * FROM data; \ No newline at end of file +SELECT * FROM data;