mirror of
https://github.com/postgresml/pgcat.git
synced 2026-03-23 17:36:28 +00:00
fixes to the banlist
This commit is contained in:
@@ -214,7 +214,14 @@ impl Client {
|
||||
|
||||
// Grab a server from the pool.
|
||||
// None = any shard
|
||||
let connection = pool.get(shard, role).await.unwrap();
|
||||
let connection = match pool.get(shard, role).await {
|
||||
Ok(conn) => conn,
|
||||
Err(err) => {
|
||||
println!(">> Could not get connection from pool: {:?}", err);
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
|
||||
let mut proxy = connection.0;
|
||||
let _address = connection.1;
|
||||
let server = &mut *proxy;
|
||||
@@ -253,10 +260,13 @@ impl Client {
|
||||
|
||||
match code {
|
||||
'Q' => {
|
||||
// TODO: implement retries here for read-only transactions.
|
||||
server.send(original).await?;
|
||||
|
||||
loop {
|
||||
// TODO: implement retries here for read-only transactions.
|
||||
let response = server.recv().await?;
|
||||
|
||||
match write_all_half(&mut self.write, response).await {
|
||||
Ok(_) => (),
|
||||
Err(err) => {
|
||||
@@ -312,10 +322,13 @@ impl Client {
|
||||
'S' => {
|
||||
// Extended protocol, client requests sync
|
||||
self.buffer.put(&original[..]);
|
||||
|
||||
// TODO: retries for read-only transactions
|
||||
server.send(self.buffer.clone()).await?;
|
||||
self.buffer.clear();
|
||||
|
||||
loop {
|
||||
// TODO: retries for read-only transactions
|
||||
let response = server.recv().await?;
|
||||
match write_all_half(&mut self.write, response).await {
|
||||
Ok(_) => (),
|
||||
|
||||
@@ -3,7 +3,7 @@ use tokio::fs::File;
|
||||
use tokio::io::AsyncReadExt;
|
||||
use toml;
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
use crate::errors::Error;
|
||||
|
||||
@@ -77,6 +77,21 @@ pub async fn parse(path: &str) -> Result<Config, Error> {
|
||||
}
|
||||
};
|
||||
|
||||
// We use addresses as unique identifiers,
|
||||
// let's make sure they are unique in the config as well.
|
||||
for shard in &config.shards {
|
||||
let mut dup_check = HashSet::new();
|
||||
|
||||
for server in &shard.1.servers {
|
||||
dup_check.insert(server);
|
||||
}
|
||||
|
||||
if dup_check.len() != shard.1.servers.len() {
|
||||
println!("> Shard {} contains duplicate server configs.", &shard.0);
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
|
||||
@@ -8,4 +8,5 @@ pub enum Error {
|
||||
// ServerTimeout,
|
||||
// DirtyServer,
|
||||
BadConfig,
|
||||
AllServersDown,
|
||||
}
|
||||
|
||||
39
src/pool.rs
39
src/pool.rs
@@ -113,7 +113,17 @@ impl ConnectionPool {
|
||||
None => 0, // TODO: pick a shard at random
|
||||
};
|
||||
|
||||
loop {
|
||||
let mut allowed_attempts = match role {
|
||||
// Primary-specific queries get one attempt, if the primary is down,
|
||||
// nothing we can do.
|
||||
Some(Role::Primary) => 1,
|
||||
|
||||
// Replicas get to try as many times as there are replicas.
|
||||
Some(Role::Replica) => self.databases[shard].len(),
|
||||
None => self.databases[shard].len(),
|
||||
};
|
||||
|
||||
while allowed_attempts > 0 {
|
||||
// TODO: think about making this local, so multiple clients
|
||||
// don't compete for the same round-robin integer.
|
||||
// Especially since we're going to be skipping (see role selection below).
|
||||
@@ -121,21 +131,27 @@ impl ConnectionPool {
|
||||
self.round_robin.fetch_add(1, Ordering::SeqCst) % self.databases[shard].len();
|
||||
let address = self.addresses[shard][index].clone();
|
||||
|
||||
if self.is_banned(&address, shard) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Make sure you're getting a primary or a replica
|
||||
// as per request.
|
||||
match role {
|
||||
Some(role) => {
|
||||
if address.role != role {
|
||||
// If the client wants a specific role,
|
||||
// we'll do our best to pick it, but if we only
|
||||
// have one server in the cluster, it's probably only a primary
|
||||
// (or only a replica), so the client will just get what we have.
|
||||
if address.role != role && self.addresses[shard].len() > 1 {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
None => (),
|
||||
};
|
||||
|
||||
if self.is_banned(&address, shard, role) {
|
||||
continue;
|
||||
}
|
||||
|
||||
allowed_attempts -= 1;
|
||||
|
||||
// Check if we can connect
|
||||
// TODO: implement query wait timeout, i.e. time to get a conn from the pool
|
||||
let mut conn = match self.databases[shard][index].get().await {
|
||||
@@ -183,6 +199,8 @@ impl ConnectionPool {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Err(Error::AllServersDown);
|
||||
}
|
||||
|
||||
/// Ban an address (i.e. replica). It no longer will serve
|
||||
@@ -204,7 +222,14 @@ impl ConnectionPool {
|
||||
|
||||
/// Check if a replica can serve traffic. If all replicas are banned,
|
||||
/// we unban all of them. Better to try then not to.
|
||||
pub fn is_banned(&self, address: &Address, shard: usize) -> bool {
|
||||
pub fn is_banned(&self, address: &Address, shard: usize, role: Option<Role>) -> bool {
|
||||
// If primary is requested explicitely, it can never be banned.
|
||||
if Some(Role::Primary) == role {
|
||||
return false;
|
||||
}
|
||||
|
||||
// If you're not asking for the primary,
|
||||
// all databases are treated as replicas.
|
||||
let mut guard = self.banlist.lock().unwrap();
|
||||
|
||||
// Everything is banned = nothing is banned.
|
||||
|
||||
Reference in New Issue
Block a user