diff --git a/pgcat.toml b/pgcat.toml index 803a342..ffcf722 100644 --- a/pgcat.toml +++ b/pgcat.toml @@ -47,6 +47,7 @@ password = "sharding_user" servers = [ [ "127.0.0.1", 5432, "primary" ], [ "localhost", 5432, "replica" ], + # [ "127.0.1.1", 5432, "replica" ], ] # Database name (e.g. "postgres") database = "shard0" @@ -56,6 +57,7 @@ database = "shard0" servers = [ [ "127.0.0.1", 5432, "primary" ], [ "localhost", 5432, "replica" ], + # [ "127.0.1.1", 5432, "replica" ], ] database = "shard1" @@ -64,5 +66,6 @@ database = "shard1" servers = [ [ "127.0.0.1", 5432, "primary" ], [ "localhost", 5432, "replica" ], + # [ "127.0.1.1", 5432, "replica" ], ] database = "shard2" diff --git a/src/client.rs b/src/client.rs index 5df6077..7d7e27a 100644 --- a/src/client.rs +++ b/src/client.rs @@ -214,7 +214,14 @@ impl Client { // Grab a server from the pool. // None = any shard - let connection = pool.get(shard, role).await.unwrap(); + let connection = match pool.get(shard, role).await { + Ok(conn) => conn, + Err(err) => { + println!(">> Could not get connection from pool: {:?}", err); + return Err(err); + } + }; + let mut proxy = connection.0; let _address = connection.1; let server = &mut *proxy; @@ -253,10 +260,13 @@ impl Client { match code { 'Q' => { + // TODO: implement retries here for read-only transactions. server.send(original).await?; loop { + // TODO: implement retries here for read-only transactions. 
let response = server.recv().await?; + match write_all_half(&mut self.write, response).await { Ok(_) => (), Err(err) => { @@ -312,10 +322,13 @@ impl Client { 'S' => { // Extended protocol, client requests sync self.buffer.put(&original[..]); + + // TODO: retries for read-only transactions server.send(self.buffer.clone()).await?; self.buffer.clear(); loop { + // TODO: retries for read-only transactions let response = server.recv().await?; match write_all_half(&mut self.write, response).await { Ok(_) => (), diff --git a/src/config.rs b/src/config.rs index 094fd79..fe9206d 100644 --- a/src/config.rs +++ b/src/config.rs @@ -3,7 +3,7 @@ use tokio::fs::File; use tokio::io::AsyncReadExt; use toml; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use crate::errors::Error; @@ -77,6 +77,21 @@ pub async fn parse(path: &str) -> Result<Config, Error> { } }; + // We use addresses as unique identifiers, + // let's make sure they are unique in the config as well. + for shard in &config.shards { + let mut dup_check = HashSet::new(); + + for server in &shard.1.servers { + dup_check.insert(server); + } + + if dup_check.len() != shard.1.servers.len() { + println!("> Shard {} contains duplicate server configs.", &shard.0); + return Err(Error::BadConfig); + } + } + Ok(config) } diff --git a/src/errors.rs b/src/errors.rs index 3dcbf74..1fc26bb 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -8,4 +8,5 @@ pub enum Error { // ServerTimeout, // DirtyServer, BadConfig, + AllServersDown, } diff --git a/src/pool.rs b/src/pool.rs index 57bc066..49e13e5 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -113,7 +113,17 @@ impl ConnectionPool { None => 0, // TODO: pick a shard at random }; - loop { + let mut allowed_attempts = match role { + // Primary-specific queries get one attempt, if the primary is down, + // nothing we can do. + Some(Role::Primary) => 1, + + // Replicas get to try as many times as there are replicas. 
+ Some(Role::Replica) => self.databases[shard].len(), + None => self.databases[shard].len(), + }; + + while allowed_attempts > 0 { // TODO: think about making this local, so multiple clients // don't compete for the same round-robin integer. // Especially since we're going to be skipping (see role selection below). @@ -121,21 +131,27 @@ impl ConnectionPool { self.round_robin.fetch_add(1, Ordering::SeqCst) % self.databases[shard].len(); let address = self.addresses[shard][index].clone(); - if self.is_banned(&address, shard) { - continue; - } - // Make sure you're getting a primary or a replica // as per request. match role { Some(role) => { - if address.role != role { + // If the client wants a specific role, + // we'll do our best to pick it, but if we only + // have one server in the cluster, it's probably only a primary + // (or only a replica), so the client will just get what we have. + if address.role != role && self.addresses[shard].len() > 1 { continue; } } None => (), }; + if self.is_banned(&address, shard, role) { + continue; + } + + allowed_attempts -= 1; + // Check if we can connect // TODO: implement query wait timeout, i.e. time to get a conn from the pool let mut conn = match self.databases[shard][index].get().await { @@ -183,6 +199,8 @@ impl ConnectionPool { } } } + + return Err(Error::AllServersDown); } /// Ban an address (i.e. replica). It no longer will serve @@ -204,7 +222,14 @@ impl ConnectionPool { /// Check if a replica can serve traffic. If all replicas are banned, /// we unban all of them. Better to try then not to. - pub fn is_banned(&self, address: &Address, shard: usize) -> bool { + pub fn is_banned(&self, address: &Address, shard: usize, role: Option<Role>) -> bool { + // If primary is requested explicitly, it can never be banned. + if Some(Role::Primary) == role { + return false; + } + + // If you're not asking for the primary, + // all databases are treated as replicas. 
let mut guard = self.banlist.lock().unwrap(); // Everything is banned = nothing is banned. diff --git a/tests/sharding/query_routing.sh b/tests/sharding/query_routing.sh index d1b2b84..78aaa60 100644 --- a/tests/sharding/query_routing.sh +++ b/tests/sharding/query_routing.sh @@ -1,7 +1,12 @@ #/bin/bash +set -e # Setup all the shards. -sudo service postgresql restart +# sudo service postgresql restart + +echo "Giving Postgres 5 seconds to start up..." + +# sleep 5 psql -f query_routing_setup.sql @@ -9,4 +14,6 @@ psql -h 127.0.0.1 -p 6432 -f query_routing_test_insert.sql psql -h 127.0.0.1 -p 6432 -f query_routing_test_select.sql +psql -e -h 127.0.0.1 -p 6432 -f query_routing_test_primary_replica.sql + psql -f query_routing_test_validate.sql \ No newline at end of file diff --git a/tests/sharding/query_routing_test_primary_replica.sql b/tests/sharding/query_routing_test_primary_replica.sql new file mode 100644 index 0000000..06a734c --- /dev/null +++ b/tests/sharding/query_routing_test_primary_replica.sql @@ -0,0 +1,13 @@ +SET SERVER ROLE TO 'primary'; +SELECT 1; + +SET SERVER ROLE TO 'replica'; +SELECT 1; + +SET SHARDING KEY TO '1234'; +SET SERVER ROLE TO 'primary'; +SELECT 1; + +SET SERVER ROLE TO 'replica'; +SET SHARDING KEY TO '4321'; +SELECT 1; \ No newline at end of file