fixes to the banlist

This commit is contained in:
Lev Kokotov
2022-02-09 21:19:14 -08:00
parent 28c70d47b6
commit a9b2a41a9b
7 changed files with 87 additions and 10 deletions

View File

@@ -47,6 +47,7 @@ password = "sharding_user"
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ],
# [ "127.0.1.1", 5432, "replica" ],
]
# Database name (e.g. "postgres")
database = "shard0"
@@ -56,6 +57,7 @@ database = "shard0"
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ],
# [ "127.0.1.1", 5432, "replica" ],
]
database = "shard1"
@@ -64,5 +66,6 @@ database = "shard1"
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ],
# [ "127.0.1.1", 5432, "replica" ],
]
database = "shard2"

View File

@@ -214,7 +214,14 @@ impl Client {
// Grab a server from the pool.
// None = any shard
let connection = pool.get(shard, role).await.unwrap();
let connection = match pool.get(shard, role).await {
Ok(conn) => conn,
Err(err) => {
println!(">> Could not get connection from pool: {:?}", err);
return Err(err);
}
};
let mut proxy = connection.0;
let _address = connection.1;
let server = &mut *proxy;
@@ -253,10 +260,13 @@ impl Client {
match code {
'Q' => {
// TODO: implement retries here for read-only transactions.
server.send(original).await?;
loop {
// TODO: implement retries here for read-only transactions.
let response = server.recv().await?;
match write_all_half(&mut self.write, response).await {
Ok(_) => (),
Err(err) => {
@@ -312,10 +322,13 @@ impl Client {
'S' => {
// Extended protocol, client requests sync
self.buffer.put(&original[..]);
// TODO: retries for read-only transactions
server.send(self.buffer.clone()).await?;
self.buffer.clear();
loop {
// TODO: retries for read-only transactions
let response = server.recv().await?;
match write_all_half(&mut self.write, response).await {
Ok(_) => (),

View File

@@ -3,7 +3,7 @@ use tokio::fs::File;
use tokio::io::AsyncReadExt;
use toml;
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use crate::errors::Error;
@@ -77,6 +77,21 @@ pub async fn parse(path: &str) -> Result<Config, Error> {
}
};
// We use addresses as unique identifiers,
// let's make sure they are unique in the config as well.
for shard in &config.shards {
let mut dup_check = HashSet::new();
for server in &shard.1.servers {
dup_check.insert(server);
}
if dup_check.len() != shard.1.servers.len() {
println!("> Shard {} contains duplicate server configs.", &shard.0);
return Err(Error::BadConfig);
}
}
Ok(config)
}

View File

@@ -8,4 +8,5 @@ pub enum Error {
// ServerTimeout,
// DirtyServer,
BadConfig,
AllServersDown,
}

View File

@@ -113,7 +113,17 @@ impl ConnectionPool {
None => 0, // TODO: pick a shard at random
};
loop {
let mut allowed_attempts = match role {
// Primary-specific queries get one attempt; if the primary is down,
// there is nothing we can do.
Some(Role::Primary) => 1,
// Replicas get to try as many times as there are replicas.
Some(Role::Replica) => self.databases[shard].len(),
None => self.databases[shard].len(),
};
while allowed_attempts > 0 {
// TODO: think about making this local, so multiple clients
// don't compete for the same round-robin integer.
// Especially since we're going to be skipping (see role selection below).
@@ -121,21 +131,27 @@ impl ConnectionPool {
self.round_robin.fetch_add(1, Ordering::SeqCst) % self.databases[shard].len();
let address = self.addresses[shard][index].clone();
if self.is_banned(&address, shard) {
continue;
}
// Make sure you're getting a primary or a replica
// as per request.
match role {
Some(role) => {
if address.role != role {
// If the client wants a specific role,
// we'll do our best to pick it, but if we only
// have one server in the cluster, it's probably only a primary
// (or only a replica), so the client will just get what we have.
if address.role != role && self.addresses[shard].len() > 1 {
continue;
}
}
None => (),
};
if self.is_banned(&address, shard, role) {
continue;
}
allowed_attempts -= 1;
// Check if we can connect
// TODO: implement query wait timeout, i.e. time to get a conn from the pool
let mut conn = match self.databases[shard][index].get().await {
@@ -183,6 +199,8 @@ impl ConnectionPool {
}
}
}
return Err(Error::AllServersDown);
}
/// Ban an address (i.e. replica). It no longer will serve
@@ -204,7 +222,14 @@ impl ConnectionPool {
/// Check if a replica can serve traffic. If all replicas are banned,
/// we unban all of them. Better to try than not to.
pub fn is_banned(&self, address: &Address, shard: usize) -> bool {
pub fn is_banned(&self, address: &Address, shard: usize, role: Option<Role>) -> bool {
// If the primary is requested explicitly, it can never be banned.
if Some(Role::Primary) == role {
return false;
}
// If you're not asking for the primary,
// all databases are treated as replicas.
let mut guard = self.banlist.lock().unwrap();
// Everything is banned = nothing is banned.

View File

@@ -1,7 +1,12 @@
#/bin/bash
set -e
# Setup all the shards.
sudo service postgresql restart
# sudo service postgresql restart
echo "Giving Postgres 5 seconds to start up..."
# sleep 5
psql -f query_routing_setup.sql
@@ -9,4 +14,6 @@ psql -h 127.0.0.1 -p 6432 -f query_routing_test_insert.sql
psql -h 127.0.0.1 -p 6432 -f query_routing_test_select.sql
psql -e -h 127.0.0.1 -p 6432 -f query_routing_test_primary_replica.sql
psql -f query_routing_test_validate.sql

View File

@@ -0,0 +1,13 @@
SET SERVER ROLE TO 'primary';
SELECT 1;
SET SERVER ROLE TO 'replica';
SELECT 1;
SET SHARDING KEY TO '1234';
SET SERVER ROLE TO 'primary';
SELECT 1;
SET SERVER ROLE TO 'replica';
SET SHARDING KEY TO '4321';
SELECT 1;