fixes to the banlist

This commit is contained in:
Lev Kokotov
2022-02-09 21:19:14 -08:00
parent 28c70d47b6
commit a9b2a41a9b
7 changed files with 87 additions and 10 deletions

View File

@@ -47,6 +47,7 @@ password = "sharding_user"
servers = [ servers = [
[ "127.0.0.1", 5432, "primary" ], [ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ], [ "localhost", 5432, "replica" ],
# [ "127.0.1.1", 5432, "replica" ],
] ]
# Database name (e.g. "postgres") # Database name (e.g. "postgres")
database = "shard0" database = "shard0"
@@ -56,6 +57,7 @@ database = "shard0"
servers = [ servers = [
[ "127.0.0.1", 5432, "primary" ], [ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ], [ "localhost", 5432, "replica" ],
# [ "127.0.1.1", 5432, "replica" ],
] ]
database = "shard1" database = "shard1"
@@ -64,5 +66,6 @@ database = "shard1"
servers = [ servers = [
[ "127.0.0.1", 5432, "primary" ], [ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ], [ "localhost", 5432, "replica" ],
# [ "127.0.1.1", 5432, "replica" ],
] ]
database = "shard2" database = "shard2"

View File

@@ -214,7 +214,14 @@ impl Client {
// Grab a server from the pool. // Grab a server from the pool.
// None = any shard // None = any shard
let connection = pool.get(shard, role).await.unwrap(); let connection = match pool.get(shard, role).await {
Ok(conn) => conn,
Err(err) => {
println!(">> Could not get connection from pool: {:?}", err);
return Err(err);
}
};
let mut proxy = connection.0; let mut proxy = connection.0;
let _address = connection.1; let _address = connection.1;
let server = &mut *proxy; let server = &mut *proxy;
@@ -253,10 +260,13 @@ impl Client {
match code { match code {
'Q' => { 'Q' => {
// TODO: implement retries here for read-only transactions.
server.send(original).await?; server.send(original).await?;
loop { loop {
// TODO: implement retries here for read-only transactions.
let response = server.recv().await?; let response = server.recv().await?;
match write_all_half(&mut self.write, response).await { match write_all_half(&mut self.write, response).await {
Ok(_) => (), Ok(_) => (),
Err(err) => { Err(err) => {
@@ -312,10 +322,13 @@ impl Client {
'S' => { 'S' => {
// Extended protocol, client requests sync // Extended protocol, client requests sync
self.buffer.put(&original[..]); self.buffer.put(&original[..]);
// TODO: retries for read-only transactions
server.send(self.buffer.clone()).await?; server.send(self.buffer.clone()).await?;
self.buffer.clear(); self.buffer.clear();
loop { loop {
// TODO: retries for read-only transactions
let response = server.recv().await?; let response = server.recv().await?;
match write_all_half(&mut self.write, response).await { match write_all_half(&mut self.write, response).await {
Ok(_) => (), Ok(_) => (),

View File

@@ -3,7 +3,7 @@ use tokio::fs::File;
use tokio::io::AsyncReadExt; use tokio::io::AsyncReadExt;
use toml; use toml;
use std::collections::HashMap; use std::collections::{HashMap, HashSet};
use crate::errors::Error; use crate::errors::Error;
@@ -77,6 +77,21 @@ pub async fn parse(path: &str) -> Result<Config, Error> {
} }
}; };
// We use addresses as unique identifiers,
// let's make sure they are unique in the config as well.
for shard in &config.shards {
let mut dup_check = HashSet::new();
for server in &shard.1.servers {
dup_check.insert(server);
}
if dup_check.len() != shard.1.servers.len() {
println!("> Shard {} contains duplicate server configs.", &shard.0);
return Err(Error::BadConfig);
}
}
Ok(config) Ok(config)
} }

View File

@@ -8,4 +8,5 @@ pub enum Error {
// ServerTimeout, // ServerTimeout,
// DirtyServer, // DirtyServer,
BadConfig, BadConfig,
AllServersDown,
} }

View File

@@ -113,7 +113,17 @@ impl ConnectionPool {
None => 0, // TODO: pick a shard at random None => 0, // TODO: pick a shard at random
}; };
loop { let mut allowed_attempts = match role {
// Primary-specific queries get one attempt, if the primary is down,
// nothing we can do.
Some(Role::Primary) => 1,
// Replicas get to try as many times as there are replicas.
Some(Role::Replica) => self.databases[shard].len(),
None => self.databases[shard].len(),
};
while allowed_attempts > 0 {
// TODO: think about making this local, so multiple clients // TODO: think about making this local, so multiple clients
// don't compete for the same round-robin integer. // don't compete for the same round-robin integer.
// Especially since we're going to be skipping (see role selection below). // Especially since we're going to be skipping (see role selection below).
@@ -121,21 +131,27 @@ impl ConnectionPool {
self.round_robin.fetch_add(1, Ordering::SeqCst) % self.databases[shard].len(); self.round_robin.fetch_add(1, Ordering::SeqCst) % self.databases[shard].len();
let address = self.addresses[shard][index].clone(); let address = self.addresses[shard][index].clone();
if self.is_banned(&address, shard) {
continue;
}
// Make sure you're getting a primary or a replica // Make sure you're getting a primary or a replica
// as per request. // as per request.
match role { match role {
Some(role) => { Some(role) => {
if address.role != role { // If the client wants a specific role,
// we'll do our best to pick it, but if we only
// have one server in the cluster, it's probably only a primary
// (or only a replica), so the client will just get what we have.
if address.role != role && self.addresses[shard].len() > 1 {
continue; continue;
} }
} }
None => (), None => (),
}; };
if self.is_banned(&address, shard, role) {
continue;
}
allowed_attempts -= 1;
// Check if we can connect // Check if we can connect
// TODO: implement query wait timeout, i.e. time to get a conn from the pool // TODO: implement query wait timeout, i.e. time to get a conn from the pool
let mut conn = match self.databases[shard][index].get().await { let mut conn = match self.databases[shard][index].get().await {
@@ -183,6 +199,8 @@ impl ConnectionPool {
} }
} }
} }
return Err(Error::AllServersDown);
} }
/// Ban an address (i.e. replica). It no longer will serve /// Ban an address (i.e. replica). It no longer will serve
@@ -204,7 +222,14 @@ impl ConnectionPool {
/// Check if a replica can serve traffic. If all replicas are banned, /// Check if a replica can serve traffic. If all replicas are banned,
/// we unban all of them. Better to try than not to. /// we unban all of them. Better to try than not to.
pub fn is_banned(&self, address: &Address, shard: usize) -> bool { pub fn is_banned(&self, address: &Address, shard: usize, role: Option<Role>) -> bool {
// If primary is requested explicitly, it can never be banned.
if Some(Role::Primary) == role {
return false;
}
// If you're not asking for the primary,
// all databases are treated as replicas.
let mut guard = self.banlist.lock().unwrap(); let mut guard = self.banlist.lock().unwrap();
// Everything is banned = nothing is banned. // Everything is banned = nothing is banned.

View File

@@ -1,7 +1,12 @@
#/bin/bash #/bin/bash
set -e
# Setup all the shards. # Setup all the shards.
sudo service postgresql restart # sudo service postgresql restart
echo "Giving Postgres 5 seconds to start up..."
# sleep 5
psql -f query_routing_setup.sql psql -f query_routing_setup.sql
@@ -9,4 +14,6 @@ psql -h 127.0.0.1 -p 6432 -f query_routing_test_insert.sql
psql -h 127.0.0.1 -p 6432 -f query_routing_test_select.sql psql -h 127.0.0.1 -p 6432 -f query_routing_test_select.sql
psql -e -h 127.0.0.1 -p 6432 -f query_routing_test_primary_replica.sql
psql -f query_routing_test_validate.sql psql -f query_routing_test_validate.sql

View File

@@ -0,0 +1,13 @@
-- Exercises role-based routing: each SELECT 1 runs after a different
-- combination of SET SERVER ROLE / SET SHARDING KEY.
-- NOTE(review): these SET statements are presumably handled by the pooler
-- rather than by Postgres itself -- confirm against the query router.

-- Explicit primary, no sharding key set.
SET SERVER ROLE TO 'primary';
SELECT 1;
-- Explicit replica, no sharding key set.
SET SERVER ROLE TO 'replica';
SELECT 1;
-- Sharding key set first, then an explicit primary.
SET SHARDING KEY TO '1234';
SET SERVER ROLE TO 'primary';
SELECT 1;
-- Role switched to replica, then the sharding key changed before the query.
SET SERVER ROLE TO 'replica';
SET SHARDING KEY TO '4321';
SELECT 1;