From 28c70d47b6b39b97eec629c4b62864d753dfdba2 Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Wed, 9 Feb 2022 20:02:20 -0800 Subject: [PATCH] #1 Primary/replica selection --- pgcat.toml | 20 ++++++++--------- src/client.rs | 59 +++++++++++++++++++++++++++++++++++++++++++++---- src/config.rs | 10 ++++++++- src/messages.rs | 10 ++++++--- src/pool.rs | 28 ++++++++++++++++++++++- src/server.rs | 7 +++++- 6 files changed, 114 insertions(+), 20 deletions(-) diff --git a/pgcat.toml b/pgcat.toml index 78f5570..803a342 100644 --- a/pgcat.toml +++ b/pgcat.toml @@ -43,26 +43,26 @@ password = "sharding_user" # Shard 0 [shards.0] -# [ host, port ] +# [ host, port, role ] servers = [ - [ "127.0.0.1", 5432 ], - [ "localhost", 5432 ], + [ "127.0.0.1", 5432, "primary" ], + [ "localhost", 5432, "replica" ], ] # Database name (e.g. "postgres") database = "shard0" [shards.1] -# [ host, port ] +# [ host, port, role ] servers = [ - [ "127.0.0.1", 5432 ], - [ "localhost", 5432 ], + [ "127.0.0.1", 5432, "primary" ], + [ "localhost", 5432, "replica" ], ] database = "shard1" [shards.2] -# [ host, port ] +# [ host, port, role ] servers = [ - [ "127.0.0.1", 5432 ], - [ "localhost", 5432 ], + [ "127.0.0.1", 5432, "primary" ], + [ "localhost", 5432, "replica" ], ] -database = "shard2" \ No newline at end of file +database = "shard2" diff --git a/src/client.rs b/src/client.rs index 85cfbe9..5df6077 100644 --- a/src/client.rs +++ b/src/client.rs @@ -7,6 +7,7 @@ use tokio::io::{AsyncReadExt, BufReader}; use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; use tokio::net::TcpStream; +use crate::config::Role; use crate::errors::Error; use crate::messages::*; use crate::pool::{ClientServerMap, ConnectionPool}; @@ -14,6 +15,7 @@ use crate::server::Server; use crate::sharding::Sharder; const SHARDING_REGEX: &str = r"SET SHARDING KEY TO '[0-9]+';"; +const ROLE_REGEX: &str = r"SET SERVER ROLE TO '(PRIMARY|REPLICA)';"; /// The client state. One of these is created per client. pub struct Client { @@ -45,6 +47,9 @@ pub struct Client { // sharding regex sharding_regex: Regex, + + // role detection regex + role_regex: Regex, } impl Client { @@ -57,6 +62,7 @@ impl Client { transaction_mode: bool, ) -> Result { let sharding_regex = Regex::new(SHARDING_REGEX).unwrap(); + let role_regex = Regex::new(ROLE_REGEX).unwrap(); loop { // Could be StartupMessage or SSLRequest @@ -114,6 +120,7 @@ impl Client { secret_key: secret_key, client_server_map: client_server_map, sharding_regex: sharding_regex, + role_regex: role_regex, }); } @@ -134,6 +141,7 @@ impl Client { secret_key: secret_key, client_server_map: client_server_map, sharding_regex: sharding_regex, + role_regex: role_regex, }); } @@ -172,6 +180,8 @@ impl Client { // - if in transaction mode, this lives for the duration of one transaction. let mut shard: Option = None; + let mut role: Option = None; + loop { // Read a complete message from the client, which normally would be // either a `Q` (query) or `P` (prepare, extended protocol). @@ -182,18 +192,29 @@ impl Client { // Parse for special select shard command. // SET SHARDING KEY TO 'bigint'; - match self.select_shard(message.clone(), pool.shards()).await { + match self.select_shard(message.clone(), pool.shards()) { Some(s) => { - set_sharding_key(&mut self.write).await?; + custom_protocol_response_ok(&mut self.write, "SET SHARDING KEY").await?; shard = Some(s); continue; } None => (), }; + // Parse for special server role selection command. + // + match self.select_role(message.clone()) { + Some(r) => { + custom_protocol_response_ok(&mut self.write, "SET SERVER ROLE").await?; + role = Some(r); + continue; + } + None => (), + }; + // Grab a server from the pool. // None = any shard - let connection = pool.get(shard).await.unwrap(); + let connection = pool.get(shard, role).await.unwrap(); let mut proxy = connection.0; let _address = connection.1; let server = &mut *proxy; @@ -252,6 +273,7 @@ impl Client { // Release server if !server.in_transaction() && self.transaction_mode { shard = None; + role = None; break; } } @@ -311,6 +333,7 @@ impl Client { // Release server if !server.in_transaction() && self.transaction_mode { shard = None; + role = None; break; } } @@ -338,6 +361,7 @@ impl Client { if !server.in_transaction() && self.transaction_mode { println!("Releasing after copy done"); shard = None; + role = None; break; } } @@ -361,7 +385,7 @@ impl Client { /// Determine if the query is part of our special syntax, extract /// the shard key, and return the shard to query based on Postgres' /// PARTITION BY HASH function. - async fn select_shard(&mut self, mut buf: BytesMut, shards: usize) -> Option { + fn select_shard(&mut self, mut buf: BytesMut, shards: usize) -> Option { let code = buf.get_u8() as char; // Only supporting simpe protocol here, so @@ -390,4 +414,31 @@ impl Client { None } } + + // Pick a primary or a replica from the pool. + fn select_role(&mut self, mut buf: BytesMut) -> Option { + let code = buf.get_u8() as char; + + // Same story as select_shard() above. + match code { + 'Q' => (), + _ => return None, + }; + + let len = buf.get_i32(); + let query = String::from_utf8_lossy(&buf[..len as usize - 4 - 1]).to_ascii_uppercase(); + + // Copy / paste from above. If we get one more of these use cases, + // it'll be time to abstract :). + if self.role_regex.is_match(&query) { + let role = query.split("'").collect::>()[1]; + match role { + "PRIMARY" => Some(Role::Primary), + "REPLICA" => Some(Role::Replica), + _ => return None, + } + } else { + None + } + } } diff --git a/src/config.rs b/src/config.rs index 39218f7..094fd79 100644 --- a/src/config.rs +++ b/src/config.rs @@ -7,10 +7,17 @@ use std::collections::HashMap; use crate::errors::Error; +#[derive(Clone, PartialEq, Deserialize, Hash, std::cmp::Eq, Debug, Copy)] +pub enum Role { + Primary, + Replica, +} + #[derive(Clone, PartialEq, Hash, std::cmp::Eq, Debug)] pub struct Address { pub host: String, pub port: String, + pub role: Role, } #[derive(Clone, PartialEq, Hash, std::cmp::Eq, Deserialize, Debug)] @@ -32,7 +39,7 @@ pub struct General { #[derive(Deserialize, Debug, Clone)] pub struct Shard { - pub servers: Vec<(String, u16)>, + pub servers: Vec<(String, u16, String)>, pub database: String, } @@ -83,5 +90,6 @@ mod test { assert_eq!(config.general.pool_size, 15); assert_eq!(config.shards.len(), 3); assert_eq!(config.shards["1"].servers[0].0, "127.0.0.1"); + assert_eq!(config.shards["0"].servers[0].2, "primary"); } } diff --git a/src/messages.rs b/src/messages.rs index 90a6700..5f17d8d 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -141,12 +141,16 @@ pub async fn md5_password( Ok(write_all(stream, message).await?) } -/// Implements a response to our custom `SET SHARDING KEY` command. +/// Implements a response to our custom `SET SHARDING KEY` +/// and `SET SERVER ROLE` commands. /// This tells the client we're ready for the next query. -pub async fn set_sharding_key(stream: &mut OwnedWriteHalf) -> Result<(), Error> { +pub async fn custom_protocol_response_ok( + stream: &mut OwnedWriteHalf, + message: &str, +) -> Result<(), Error> { let mut res = BytesMut::with_capacity(25); - let set_complete = BytesMut::from(&"SET SHARDING KEY\0"[..]); + let set_complete = BytesMut::from(&format!("{}\0", message)[..]); let len = (set_complete.len() + 4) as i32; // CommandComplete diff --git a/src/pool.rs b/src/pool.rs index 624818a..57bc066 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -3,7 +3,7 @@ use async_trait::async_trait; use bb8::{ManageConnection, Pool, PooledConnection}; use chrono::naive::NaiveDateTime; -use crate::config::{Address, Config, User}; +use crate::config::{Address, Config, Role, User}; use crate::errors::Error; use crate::server::Server; @@ -48,9 +48,19 @@ impl ConnectionPool { let mut replica_addresses = Vec::new(); for server in &shard.servers { + let role = match server.2.as_ref() { + "primary" => Role::Primary, + "replica" => Role::Replica, + _ => { + println!("> Config error: server role can be 'primary' or 'replica', have: '{}'. Defaulting to 'replica'.", server.2); + Role::Replica + } + }; + let address = Address { host: server.0.clone(), port: server.1.to_string(), + role: role, }; let manager = ServerPool::new( @@ -93,6 +103,7 @@ impl ConnectionPool { pub async fn get( &self, shard: Option, + role: Option, ) -> Result<(PooledConnection<'_, ServerPool>, Address), Error> { // Set this to false to gain ~3-4% speed. let with_health_check = true; @@ -103,6 +114,9 @@ impl ConnectionPool { }; loop { + // TODO: think about making this local, so multiple clients + // don't compete for the same round-robin integer. + // Especially since we're going to be skipping (see role selection below). let index = self.round_robin.fetch_add(1, Ordering::SeqCst) % self.databases[shard].len(); let address = self.addresses[shard][index].clone(); @@ -111,6 +125,17 @@ impl ConnectionPool { continue; } + // Make sure you're getting a primary or a replica + // as per request. + match role { + Some(role) => { + if address.role != role { + continue; + } + } + None => (), + }; + // Check if we can connect // TODO: implement query wait timeout, i.e. time to get a conn from the pool let mut conn = match self.databases[shard][index].get().await { @@ -251,6 +276,7 @@ impl ManageConnection for ServerPool { &self.user.password, &self.database, self.client_server_map.clone(), + self.address.role, ) .await } diff --git a/src/server.rs b/src/server.rs index 0a7e31f..13f362a 100644 --- a/src/server.rs +++ b/src/server.rs @@ -8,7 +8,7 @@ use tokio::io::{AsyncReadExt, BufReader}; use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; use tokio::net::TcpStream; -use crate::config::Address; +use crate::config::{Address, Role}; use crate::errors::Error; use crate::messages::*; use crate::ClientServerMap; @@ -48,6 +48,8 @@ pub struct Server { // Mapping of clients and servers used for query cancellation. client_server_map: ClientServerMap, + + role: Role, } impl Server { @@ -60,6 +62,7 @@ impl Server { password: &str, database: &str, client_server_map: ClientServerMap, + role: Role, ) -> Result { let mut stream = match TcpStream::connect(&format!("{}:{}", host, port)).await { Ok(stream) => stream, @@ -189,6 +192,7 @@ impl Server { data_available: false, bad: false, client_server_map: client_server_map, + role: role, }); } @@ -409,6 +413,7 @@ impl Server { Address { host: self.host.to_string(), port: self.port.to_string(), + role: self.role, } } }