diff --git a/Cargo.lock b/Cargo.lock index 1673d3a..d11410a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -322,6 +322,7 @@ dependencies = [ "bb8", "bytes", "chrono", + "log", "md-5", "num_cpus", "once_cell", @@ -330,6 +331,7 @@ dependencies = [ "serde", "serde_derive", "sha-1", + "sqlparser", "statsd", "tokio", "toml", @@ -492,6 +494,15 @@ version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f2dd574626839106c320a323308629dcb1acfc96e32a8cba364ddc61ac23ee83" +[[package]] +name = "sqlparser" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8f192f29f4aa49e57bebd0aa05858e0a1f32dd270af36efe49edb82cbfffab6" +dependencies = [ + "log", +] + [[package]] name = "statsd" version = "0.15.0" diff --git a/Cargo.toml b/Cargo.toml index 0db9ecd..ba78b59 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,3 +21,5 @@ regex = "1" num_cpus = "1" once_cell = "1" statsd = "0.15" +sqlparser = "0.14" +log = "0.4" diff --git a/pgcat.toml b/pgcat.toml index db5c3f2..0fa8b6a 100644 --- a/pgcat.toml +++ b/pgcat.toml @@ -81,3 +81,15 @@ database = "shard2" # replica: round-robin between replicas only without touching the primary, # primary: all queries go to the primary unless otherwise specified. default_role = "any" + + +# Query parser. If enabled, we'll attempt to parse +# every incoming query to determine if it's a read or a write. +# If it's a read query, we'll direct it to a replica. Otherwise, if it's a write, +# we'll direct it to the primary. +query_parser_enabled = false + +# If the query parser is enabled and this setting is enabled, the primary will be part of the pool of databases used for +# load balancing of read queries. Otherwise, the primary will only be used for write +# queries. The primary can always be explicitly selected with our custom protocol. 
+primary_reads_enabled = true diff --git a/src/client.rs b/src/client.rs index dd1b7cc..184d46d 100644 --- a/src/client.rs +++ b/src/client.rs @@ -10,7 +10,6 @@ use tokio::net::{ use std::collections::HashMap; -use crate::config::Role; use crate::constants::*; use crate::errors::Error; use crate::messages::*; @@ -47,10 +46,6 @@ pub struct Client { // to connect and cancel a query. client_server_map: ClientServerMap, - // Unless client specifies, route queries to the servers that have this role, - // e.g. primary or replicas or any. - default_server_role: Option, - // Client parameters, e.g. user, client_encoding, etc. #[allow(dead_code)] parameters: HashMap, @@ -67,7 +62,6 @@ impl Client { mut stream: TcpStream, client_server_map: ClientServerMap, transaction_mode: bool, - default_server_role: Option, server_info: BytesMut, stats: Reporter, ) -> Result { @@ -126,7 +120,6 @@ impl Client { process_id: process_id, secret_key: secret_key, client_server_map: client_server_map, - default_server_role: default_server_role, parameters: parameters, stats: stats, }); @@ -148,7 +141,6 @@ impl Client { process_id: process_id, secret_key: secret_key, client_server_map: client_server_map, - default_server_role: default_server_role, parameters: HashMap::new(), stats: stats, }); @@ -162,7 +154,11 @@ impl Client { } /// Client loop. We handle all messages between the client and the database here. - pub async fn handle(&mut self, mut pool: ConnectionPool) -> Result<(), Error> { + pub async fn handle( + &mut self, + mut pool: ConnectionPool, + mut query_router: QueryRouter, + ) -> Result<(), Error> { // The client wants to cancel a query it has issued previously. if self.cancel_mode { let (process_id, secret_key, address, port) = { @@ -191,8 +187,6 @@ impl Client { return Ok(Server::cancel(&address, &port, process_id, secret_key).await?); } - let mut query_router = QueryRouter::new(self.default_server_role, pool.shards()); - // Our custom protocol loop. 
// We expect the client to either start a transaction with regular queries // or issue commands for our sharding and server selection protocols. @@ -222,6 +216,11 @@ impl Client { continue; } + // Attempt to parse the query to determine where it should go + if query_router.query_parser_enabled() && query_router.role() == None { + query_router.infer_role(message.clone()); + } + // Grab a server from the pool: the client issued a regular query. let connection = match pool.get(query_router.shard(), query_router.role()).await { Ok(conn) => conn, diff --git a/src/config.rs b/src/config.rs index 1a3f22b..8f7c45f 100644 --- a/src/config.rs +++ b/src/config.rs @@ -13,6 +13,24 @@ pub enum Role { Replica, } +impl PartialEq> for Role { + fn eq(&self, other: &Option) -> bool { + match other { + None => true, + Some(role) => *self == *role, + } + } +} + +impl PartialEq for Option { + fn eq(&self, other: &Role) -> bool { + match *self { + None => true, + Some(role) => role == *other, + } + } +} + #[derive(Clone, PartialEq, Hash, std::cmp::Eq, Debug)] pub struct Address { pub host: String, @@ -47,6 +65,8 @@ pub struct Shard { #[derive(Deserialize, Debug, Clone)] pub struct QueryRouter { pub default_role: String, + pub query_parser_enabled: bool, + pub primary_reads_enabled: bool, } #[derive(Deserialize, Debug, Clone)] diff --git a/src/main.rs b/src/main.rs index cd74162..1669331 100644 --- a/src/main.rs +++ b/src/main.rs @@ -16,11 +16,13 @@ extern crate async_trait; extern crate bb8; extern crate bytes; +extern crate log; extern crate md5; extern crate num_cpus; extern crate once_cell; extern crate serde; extern crate serde_derive; +extern crate sqlparser; extern crate statsd; extern crate tokio; extern crate toml; @@ -47,6 +49,7 @@ mod stats; // secret keys to the backend's. use config::Role; use pool::{ClientServerMap, ConnectionPool}; +use query_router::QueryRouter; use stats::{Collector, Reporter}; /// Main! 
@@ -118,6 +121,8 @@ async fn main() { return; } }; + let primary_reads_enabled = config.query_router.primary_reads_enabled; + let query_parser_enabled = config.query_router.query_parser_enabled; let server_info = match pool.validate().await { Ok(info) => info, @@ -155,7 +160,6 @@ async fn main() { socket, client_server_map, transaction_mode, - default_server_role, server_info, reporter, ) @@ -164,7 +168,14 @@ async fn main() { Ok(mut client) => { println!(">> Client {:?} authenticated successfully!", addr); - match client.handle(pool).await { + let query_router = QueryRouter::new( + default_server_role, + pool.shards(), + primary_reads_enabled, + query_parser_enabled, + ); + + match client.handle(pool, query_router).await { Ok(()) => { let duration = chrono::offset::Utc::now().naive_utc() - start; diff --git a/src/pool.rs b/src/pool.rs index a5e32d1..6a190f1 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -25,7 +25,6 @@ pub struct ConnectionPool { banlist: BanList, healthcheck_timeout: u64, ban_time: i64, - pool_size: u32, stats: Reporter, } @@ -47,12 +46,12 @@ impl ConnectionPool { .collect::>(); shard_ids.sort_by_key(|k| k.parse::().unwrap()); - for shard in shard_ids { - let shard = &config.shards[&shard]; + for shard_idx in shard_ids { + let shard = &config.shards[&shard_idx]; let mut pools = Vec::new(); let mut replica_addresses = Vec::new(); - for (idx, server) in shard.servers.iter().enumerate() { + for server in shard.servers.iter() { let role = match server.2.as_ref() { "primary" => Role::Primary, "replica" => Role::Replica, @@ -66,7 +65,7 @@ impl ConnectionPool { host: server.0.clone(), port: server.1.to_string(), role: role, - shard: idx, + shard: shard_idx.parse::().unwrap(), }; let manager = ServerPool::new( @@ -106,7 +105,6 @@ impl ConnectionPool { banlist: Arc::new(Mutex::new(banlist)), healthcheck_timeout: config.general.healthcheck_timeout, ban_time: config.general.ban_time, - pool_size: config.general.pool_size, stats: stats, } } @@ -120,12 +118,12 
@@ impl ConnectionPool { let mut server_infos = Vec::new(); for shard in 0..self.shards() { - for _ in 0..self.replicas(shard) { + for _ in 0..self.servers(shard) { let connection = match self.get(shard, None).await { Ok(conn) => conn, Err(err) => { - println!("> Shard {} down or misconfigured.", shard); - return Err(err); + println!("> Shard {} down or misconfigured: {:?}", shard, err); + continue; } }; @@ -152,8 +150,6 @@ impl ConnectionPool { shard: usize, role: Option, ) -> Result<(PooledConnection<'_, ServerPool>, Address), Error> { - // Set this to false to gain ~3-4% speed. - let with_health_check = true; let now = Instant::now(); // We are waiting for a server now. @@ -161,25 +157,6 @@ impl ConnectionPool { let addresses = &self.addresses[shard]; - // Make sure if a specific role is requested, it's available in the pool. - match role { - Some(role) => { - let role_count = addresses.iter().filter(|&db| db.role == role).count(); - - if role_count == 0 { - println!( - ">> Error: Role '{:?}' requested, but none are configured.", - role - ); - - return Err(Error::AllServersDown); - } - } - - // Any role should be present. - _ => (), - }; - let mut allowed_attempts = match role { // Primary-specific queries get one attempt, if the primary is down, // nothing we should do about it I think. It's dangerous to retry @@ -188,9 +165,22 @@ impl ConnectionPool { // Replicas get to try as many times as there are replicas // and connections in the pool. - _ => self.databases[shard].len() * self.pool_size as usize, + _ => addresses.len(), }; + let exists = match role { + Some(role) => addresses.iter().filter(|addr| addr.role == role).count() > 0, + None => true, + }; + + if !exists { + log::error!( + "ConnectionPool::get Requested role {:?}, but none is configured.", + role + ); + return Err(Error::BadConfig); + } + while allowed_attempts > 0 { // Round-robin each client's queries. 
// If a client only sends one query and then disconnects, it doesn't matter @@ -200,23 +190,18 @@ impl ConnectionPool { let address = &addresses[index]; // Make sure you're getting a primary or a replica - // as per request. - match role { - Some(role) => { - // Find the specific role the client wants in the pool. - if address.role != role { - continue; - } - } - None => (), - }; - - if self.is_banned(address, shard, role) { + // as per request. If no specific role is requested, the first + // available will be chosen. + if address.role != role { continue; } allowed_attempts -= 1; + if self.is_banned(address, shard, role) { + continue; + } + // Check if we can connect let mut conn = match self.databases[shard][index].get().await { Ok(conn) => conn, @@ -227,12 +212,6 @@ impl ConnectionPool { } }; - if !with_health_check { - self.stats.checkout_time(now.elapsed().as_micros()); - self.stats.client_active(); - return Ok((conn, address.clone())); - } - // // Check if this server is alive with a health check let server = &mut *conn; @@ -299,17 +278,21 @@ impl ConnectionPool { /// Check if a replica can serve traffic. If all replicas are banned, /// we unban all of them. Better to try then not to. pub fn is_banned(&self, address: &Address, shard: usize, role: Option) -> bool { - // If primary is requested explicitely, it can never be banned. - if Some(Role::Primary) == role { - return false; - } + let replicas_available = match role { + Some(Role::Replica) => self.addresses[shard] + .iter() + .filter(|addr| addr.role == Role::Replica) + .count(), + None => self.addresses[shard].len(), + Some(Role::Primary) => return false, // Primary cannot be banned. + }; // If you're not asking for the primary, // all databases are treated as replicas. let mut guard = self.banlist.lock().unwrap(); // Everything is banned = nothing is banned. 
- if guard[shard].len() == self.databases[shard].len() { + if guard[shard].len() == replicas_available { guard[shard].clear(); drop(guard); println!(">> Unbanning all replicas."); @@ -337,7 +320,7 @@ impl ConnectionPool { self.databases.len() } - pub fn replicas(&self, shard: usize) -> usize { + pub fn servers(&self, shard: usize) -> usize { self.addresses[shard].len() } } diff --git a/src/query_router.rs b/src/query_router.rs index aa73d56..f5bdc99 100644 --- a/src/query_router.rs +++ b/src/query_router.rs @@ -1,8 +1,11 @@ -use bytes::{Buf, BytesMut}; /// Route queries automatically based on explicitely requested /// or implied query characteristics. +use bytes::{Buf, BytesMut}; use once_cell::sync::OnceCell; use regex::{Regex, RegexBuilder}; +use sqlparser::ast::Statement::{Query, StartTransaction}; +use sqlparser::dialect::PostgreSqlDialect; +use sqlparser::parser::Parser; use crate::config::Role; use crate::sharding::Sharder; @@ -26,6 +29,12 @@ pub struct QueryRouter { // Should we be talking to a primary or a replica? active_role: Option, + + // Include the primary into the replica pool? + primary_reads_enabled: bool, + + // Should we try to parse queries? + query_parser_enabled: bool, } impl QueryRouter { @@ -54,13 +63,20 @@ impl QueryRouter { a && b } - pub fn new(default_server_role: Option, shards: usize) -> QueryRouter { + pub fn new( + default_server_role: Option, + shards: usize, + primary_reads_enabled: bool, + query_parser_enabled: bool, + ) -> QueryRouter { QueryRouter { default_server_role: default_server_role, shards: shards, active_role: default_server_role, active_shard: None, + primary_reads_enabled: primary_reads_enabled, + query_parser_enabled: query_parser_enabled, } } @@ -109,7 +125,7 @@ impl QueryRouter { } } - // Pick a primary or a replica from the pool. + /// Pick a primary or a replica from the pool. 
pub fn select_role(&mut self, mut buf: BytesMut) -> bool { let code = buf.get_u8() as char; @@ -150,6 +166,75 @@ impl QueryRouter { } } + /// Try to infer which server to connect to based on the contents of the query. + pub fn infer_role(&mut self, mut buf: BytesMut) -> bool { + let code = buf.get_u8() as char; + let len = buf.get_i32() as usize; + + let query = match code { + 'Q' => String::from_utf8_lossy(&buf[..len - 5]).to_string(), + 'P' => { + let mut start = 0; + let mut end; + + // Skip the name of the prepared statement. + while buf[start] != 0 && start < buf.len() { + start += 1; + } + start += 1; // Skip terminating null + + // Find the end of the prepared stmt (\0) + end = start; + while buf[end] != 0 && end < buf.len() { + end += 1; + } + + let query = String::from_utf8_lossy(&buf[start..end]).to_string(); + + query.replace("$", "") // Remove placeholders turning them into "values" + } + _ => return false, + }; + + let ast = match Parser::parse_sql(&PostgreSqlDialect {}, &query) { + Ok(ast) => ast, + Err(err) => { + log::debug!( + "QueryParser::infer_role could not parse query, error: {:?}, query: {}", + err, + query + ); + return false; + } + }; + + if ast.len() == 0 { + return false; + } + + match ast[0] { + // All transactions go to the primary, probably a write. + StartTransaction { .. } => { + self.active_role = Some(Role::Primary); + } + + // Likely a read-only query + Query { .. } => { + self.active_role = match self.primary_reads_enabled { + false => Some(Role::Replica), // If primary should not be receiving reads, use a replica. + true => None, // Any server role is fine in this case. + } + } + + // Likely a write + _ => { + self.active_role = Some(Role::Primary); + } + }; + + true + } + /// Get the current desired server role we should be talking to. 
pub fn role(&self) -> Option { self.active_role @@ -169,6 +254,11 @@ impl QueryRouter { self.active_role = self.default_server_role; self.active_shard = None; } + + /// Should we attempt to parse queries? + pub fn query_parser_enabled(&self) -> bool { + self.query_parser_enabled + } } #[cfg(test)] @@ -182,7 +272,7 @@ mod test { let default_server_role: Option = None; let shards = 5; - let mut query_router = QueryRouter::new(default_server_role, shards); + let mut query_router = QueryRouter::new(default_server_role, shards, false, false); // Build the special syntax query. let mut message = BytesMut::new(); @@ -205,7 +295,7 @@ mod test { let default_server_role: Option = None; let shards = 5; - let mut query_router = QueryRouter::new(default_server_role, shards); + let mut query_router = QueryRouter::new(default_server_role, shards, false, false); // Build the special syntax query. let mut message = BytesMut::new(); @@ -229,7 +319,7 @@ mod test { let default_server_role: Option = None; let shards = 5; - let query_router = QueryRouter::new(default_server_role, shards); + let query_router = QueryRouter::new(default_server_role, shards, false, false); assert_eq!(query_router.shard(), 0); assert_eq!(query_router.role(), None); @@ -241,7 +331,7 @@ mod test { let default_server_role: Option = None; let shards = 5; - let mut query_router = QueryRouter::new(default_server_role, shards); + let mut query_router = QueryRouter::new(default_server_role, shards, false, false); // Build the special syntax query. 
let mut message = BytesMut::new(); @@ -256,4 +346,97 @@ mod test { assert_eq!(query_router.select_shard(message.clone()), false); assert_eq!(query_router.select_role(message.clone()), false); } + + #[test] + fn test_infer_role_replica() { + QueryRouter::setup(); + + let default_server_role: Option = None; + let shards = 5; + + let mut query_router = QueryRouter::new(default_server_role, shards, false, false); + + let queries = vec![ + BytesMut::from(&b"SELECT * FROM items WHERE id = 5\0"[..]), + BytesMut::from(&b"SELECT id, name, value FROM items INNER JOIN prices ON item.id = prices.item_id\0"[..]), + BytesMut::from(&b"WITH t AS (SELECT * FROM items) SELECT * FROM t\0"[..]), + ]; + + for query in &queries { + let mut res = BytesMut::from(&b"Q"[..]); + res.put_i32(query.len() as i32 + 4); + res.put(query.clone()); + + // It's a recognized query + assert!(query_router.infer_role(res)); + assert_eq!(query_router.role(), Some(Role::Replica)); + } + } + + #[test] + fn test_infer_role_primary() { + QueryRouter::setup(); + + let default_server_role: Option = None; + let shards = 5; + + let mut query_router = QueryRouter::new(default_server_role, shards, false, false); + + let queries = vec![ + BytesMut::from(&b"UPDATE items SET name = 'pumpkin' WHERE id = 5\0"[..]), + BytesMut::from(&b"INSERT INTO items (id, name) VALUES (5, 'pumpkin')\0"[..]), + BytesMut::from(&b"DELETE FROM items WHERE id = 5\0"[..]), + BytesMut::from(&b"BEGIN\0"[..]), // Transaction start + ]; + + for query in &queries { + let mut res = BytesMut::from(&b"Q"[..]); + res.put_i32(query.len() as i32 + 4); + res.put(query.clone()); + + // It's a recognized query + assert!(query_router.infer_role(res)); + assert_eq!(query_router.role(), Some(Role::Primary)); + } + } + + #[test] + fn test_infer_role_primary_reads_enabled() { + QueryRouter::setup(); + + let default_server_role: Option = None; + let shards = 5; + + let mut query_router = QueryRouter::new(default_server_role, shards, true, false); + + let query 
= BytesMut::from(&b"SELECT * FROM items WHERE id = 5\0"[..]); + let mut res = BytesMut::from(&b"Q"[..]); + res.put_i32(query.len() as i32 + 4); + res.put(query.clone()); + + assert!(query_router.infer_role(res)); + assert_eq!(query_router.role(), None); + } + + #[test] + fn test_infer_role_parse_prepared() { + QueryRouter::setup(); + + let default_server_role: Option = None; + let shards = 5; + + let mut query_router = QueryRouter::new(default_server_role, shards, false, false); + + let prepared_stmt = BytesMut::from( + &b"WITH t AS (SELECT * FROM items WHERE name = $1) SELECT * FROM t WHERE id = $2\0"[..], + ); + let mut res = BytesMut::from(&b"P"[..]); + res.put_i32(prepared_stmt.len() as i32 + 4 + 1 + 2); + res.put_u8(0); + res.put(prepared_stmt); + res.put_i16(0); + + assert!(query_router.infer_role(res)); + assert_eq!(query_router.role(), Some(Role::Replica)); + } }