mirror of
https://github.com/postgresml/pgcat.git
synced 2026-03-26 18:36:28 +00:00
correct load balancing
This commit is contained in:
@@ -9,9 +9,8 @@ use tokio::net::TcpStream;
|
|||||||
|
|
||||||
use crate::errors::Error;
|
use crate::errors::Error;
|
||||||
use crate::messages::*;
|
use crate::messages::*;
|
||||||
use crate::pool::{ClientServerMap, ServerPool};
|
use crate::pool::{ClientServerMap, ConnectionPool};
|
||||||
use crate::server::Server;
|
use crate::server::Server;
|
||||||
use bb8::Pool;
|
|
||||||
|
|
||||||
/// The client state.
|
/// The client state.
|
||||||
pub struct Client {
|
pub struct Client {
|
||||||
@@ -125,7 +124,7 @@ impl Client {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Client loop. We handle all messages between the client and the database here.
|
/// Client loop. We handle all messages between the client and the database here.
|
||||||
pub async fn handle(&mut self, pool: Pool<ServerPool>) -> Result<(), Error> {
|
pub async fn handle(&mut self, pool: ConnectionPool) -> Result<(), Error> {
|
||||||
// Special: cancelling existing running query
|
// Special: cancelling existing running query
|
||||||
if self.cancel_mode {
|
if self.cancel_mode {
|
||||||
let (process_id, secret_key, address, port) = {
|
let (process_id, secret_key, address, port) = {
|
||||||
@@ -148,13 +147,17 @@ impl Client {
|
|||||||
loop {
|
loop {
|
||||||
// Only grab a connection once we have some traffic on the socket
|
// Only grab a connection once we have some traffic on the socket
|
||||||
// TODO: this is not the most optimal way to share servers.
|
// TODO: this is not the most optimal way to share servers.
|
||||||
let mut peek_buf = vec![0u8; 2];
|
// let mut peek_buf = vec![0u8; 2];
|
||||||
|
|
||||||
match self.read.get_mut().peek(&mut peek_buf).await {
|
// match self.read.get_mut().peek(&mut peek_buf).await {
|
||||||
Ok(_) => (),
|
// Ok(_) => (),
|
||||||
Err(_) => return Err(Error::ClientDisconnected),
|
// Err(_) => return Err(Error::ClientDisconnected),
|
||||||
};
|
// };
|
||||||
let mut proxy = pool.get().await.unwrap();
|
let message = read_message(&mut self.read).await?;
|
||||||
|
|
||||||
|
self.buffer.put(message);
|
||||||
|
|
||||||
|
let mut proxy = pool.get(None).await.unwrap().0;
|
||||||
let server = &mut *proxy;
|
let server = &mut *proxy;
|
||||||
|
|
||||||
// TODO: maybe don't do this, I don't think it's useful.
|
// TODO: maybe don't do this, I don't think it's useful.
|
||||||
@@ -164,18 +167,28 @@ impl Client {
|
|||||||
server.claim(self.process_id, self.secret_key);
|
server.claim(self.process_id, self.secret_key);
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
let mut message = match read_message(&mut self.read).await {
|
let mut message = match self.buffer.len() {
|
||||||
Ok(message) => message,
|
0 => {
|
||||||
Err(err) => {
|
match read_message(&mut self.read).await {
|
||||||
if server.in_transaction() {
|
Ok(message) => message,
|
||||||
// TODO: this is what PgBouncer does
|
Err(err) => {
|
||||||
// which leads to connection thrashing.
|
if server.in_transaction() {
|
||||||
//
|
// TODO: this is what PgBouncer does
|
||||||
// I think we could issue a ROLLBACK here instead.
|
// which leads to connection thrashing.
|
||||||
server.mark_bad();
|
//
|
||||||
}
|
// I think we could issue a ROLLBACK here instead.
|
||||||
|
server.mark_bad();
|
||||||
|
}
|
||||||
|
|
||||||
return Err(err);
|
return Err(err);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
_ => {
|
||||||
|
let message = self.buffer.clone();
|
||||||
|
self.buffer.clear();
|
||||||
|
message
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
18
src/main.rs
18
src/main.rs
@@ -21,7 +21,6 @@ extern crate tokio;
|
|||||||
|
|
||||||
use tokio::net::TcpListener;
|
use tokio::net::TcpListener;
|
||||||
|
|
||||||
use bb8::Pool;
|
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::{Arc, Mutex};
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
@@ -35,12 +34,7 @@ mod server;
|
|||||||
// Support for query cancellation: this maps our process_ids and
|
// Support for query cancellation: this maps our process_ids and
|
||||||
// secret keys to the backend's.
|
// secret keys to the backend's.
|
||||||
use config::{Address, User};
|
use config::{Address, User};
|
||||||
use pool::{ClientServerMap, ReplicaPool, ServerPool};
|
use pool::{ClientServerMap, ConnectionPool};
|
||||||
|
|
||||||
//
|
|
||||||
// Poor man's config
|
|
||||||
//
|
|
||||||
const POOL_SIZE: u32 = 15;
|
|
||||||
|
|
||||||
/// Main!
|
/// Main!
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
@@ -71,7 +65,6 @@ async fn main() {
|
|||||||
port: "5432".to_string(),
|
port: "5432".to_string(),
|
||||||
},
|
},
|
||||||
];
|
];
|
||||||
let num_addresses = addresses.len() as u32;
|
|
||||||
|
|
||||||
let user = User {
|
let user = User {
|
||||||
name: "lev".to_string(),
|
name: "lev".to_string(),
|
||||||
@@ -80,9 +73,7 @@ async fn main() {
|
|||||||
|
|
||||||
let database = "lev";
|
let database = "lev";
|
||||||
|
|
||||||
let replica_pool = ReplicaPool::new(addresses).await;
|
let pool = ConnectionPool::new(addresses, user, database, client_server_map.clone()).await;
|
||||||
let manager = ServerPool::new(replica_pool, user, database, client_server_map.clone());
|
|
||||||
|
|
||||||
// We are round-robining, so ideally the replicas will be equally loaded.
|
// We are round-robining, so ideally the replicas will be equally loaded.
|
||||||
// Therefore, we are allocating number of replicas * pool size of connections.
|
// Therefore, we are allocating number of replicas * pool size of connections.
|
||||||
// However, if a replica dies, the remaining replicas will share the burden,
|
// However, if a replica dies, the remaining replicas will share the burden,
|
||||||
@@ -91,11 +82,6 @@ async fn main() {
|
|||||||
// Note that failover in this case could bring down the remaining replicas, so
|
// Note that failover in this case could bring down the remaining replicas, so
|
||||||
// in certain situations, e.g. when replicas are running hot already, failover
|
// in certain situations, e.g. when replicas are running hot already, failover
|
||||||
// is not at all desirable!!
|
// is not at all desirable!!
|
||||||
let pool = Pool::builder()
|
|
||||||
.max_size(POOL_SIZE * num_addresses)
|
|
||||||
.build(manager)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
let pool = pool.clone();
|
let pool = pool.clone();
|
||||||
|
|||||||
261
src/pool.rs
261
src/pool.rs
@@ -1,6 +1,6 @@
|
|||||||
/// Pooling and failover and banlist.
|
/// Pooling and failover and banlist.
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use bb8::{ManageConnection, PooledConnection};
|
use bb8::{ManageConnection, Pool, PooledConnection};
|
||||||
use chrono::naive::NaiveDateTime;
|
use chrono::naive::NaiveDateTime;
|
||||||
|
|
||||||
use crate::config::{Address, User};
|
use crate::config::{Address, User};
|
||||||
@@ -21,112 +21,119 @@ pub type ClientServerMap = Arc<Mutex<HashMap<(i32, i32), (i32, i32, String, Stri
|
|||||||
// 60 seconds of ban time.
|
// 60 seconds of ban time.
|
||||||
// After that, the replica will be allowed to serve traffic again.
|
// After that, the replica will be allowed to serve traffic again.
|
||||||
const BAN_TIME: i64 = 60;
|
const BAN_TIME: i64 = 60;
|
||||||
|
//
|
||||||
|
// Poor man's config
|
||||||
|
//
|
||||||
|
const POOL_SIZE: u32 = 15;
|
||||||
|
|
||||||
pub struct ServerPool {
|
|
||||||
replica_pool: ReplicaPool,
|
|
||||||
user: User,
|
|
||||||
database: String,
|
|
||||||
client_server_map: ClientServerMap,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ServerPool {
|
|
||||||
pub fn new(
|
|
||||||
replica_pool: ReplicaPool,
|
|
||||||
user: User,
|
|
||||||
database: &str,
|
|
||||||
client_server_map: ClientServerMap,
|
|
||||||
) -> ServerPool {
|
|
||||||
ServerPool {
|
|
||||||
replica_pool: replica_pool,
|
|
||||||
user: user,
|
|
||||||
database: database.to_string(),
|
|
||||||
client_server_map: client_server_map,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[async_trait]
|
|
||||||
impl ManageConnection for ServerPool {
|
|
||||||
type Connection = Server;
|
|
||||||
type Error = Error;
|
|
||||||
|
|
||||||
/// Attempts to create a new connection.
|
|
||||||
async fn connect(&self) -> Result<Self::Connection, Self::Error> {
|
|
||||||
println!(">> Getting new connection from the pool");
|
|
||||||
let address = self.replica_pool.get();
|
|
||||||
|
|
||||||
match Server::startup(
|
|
||||||
&address.host,
|
|
||||||
&address.port,
|
|
||||||
&self.user.name,
|
|
||||||
&self.user.password,
|
|
||||||
&self.database,
|
|
||||||
self.client_server_map.clone(),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(server) => {
|
|
||||||
self.replica_pool.unban(&address);
|
|
||||||
Ok(server)
|
|
||||||
}
|
|
||||||
Err(err) => {
|
|
||||||
self.replica_pool.ban(&address);
|
|
||||||
Err(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Determines if the connection is still connected to the database.
|
|
||||||
async fn is_valid(&self, conn: &mut PooledConnection<'_, Self>) -> Result<(), Self::Error> {
|
|
||||||
let server = &mut *conn;
|
|
||||||
|
|
||||||
// Client disconnected before cleaning up
|
|
||||||
if server.in_transaction() {
|
|
||||||
return Err(Error::DirtyServer);
|
|
||||||
}
|
|
||||||
|
|
||||||
// If this fails, the connection will be closed and another will be grabbed from the pool quietly :-).
|
|
||||||
// Failover, step 1, complete.
|
|
||||||
match tokio::time::timeout(
|
|
||||||
tokio::time::Duration::from_millis(1000),
|
|
||||||
server.query("SELECT 1"),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(_) => Ok(()),
|
|
||||||
Err(_err) => {
|
|
||||||
println!(">> Unhealthy!");
|
|
||||||
self.replica_pool.ban(&server.address());
|
|
||||||
Err(Error::ServerTimeout)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Synchronously determine if the connection is no longer usable, if possible.
|
|
||||||
fn has_broken(&self, conn: &mut Self::Connection) -> bool {
|
|
||||||
conn.is_bad()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A collection of addresses, which could either be a single primary,
|
|
||||||
/// many sharded primaries or replicas.
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct ReplicaPool {
|
pub struct ConnectionPool {
|
||||||
|
databases: Vec<Pool<ServerPool>>,
|
||||||
addresses: Vec<Address>,
|
addresses: Vec<Address>,
|
||||||
round_robin: Counter,
|
round_robin: Counter,
|
||||||
banlist: BanList,
|
banlist: BanList,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ReplicaPool {
|
impl ConnectionPool {
|
||||||
/// Create a new replica pool. Addresses must be known in advance.
|
pub async fn new(
|
||||||
pub async fn new(addresses: Vec<Address>) -> ReplicaPool {
|
addresses: Vec<Address>,
|
||||||
ReplicaPool {
|
user: User,
|
||||||
|
database: &str,
|
||||||
|
client_server_map: ClientServerMap,
|
||||||
|
) -> ConnectionPool {
|
||||||
|
let mut databases = Vec::new();
|
||||||
|
|
||||||
|
for address in &addresses {
|
||||||
|
let manager = ServerPool::new(
|
||||||
|
address.clone(),
|
||||||
|
user.clone(),
|
||||||
|
database,
|
||||||
|
client_server_map.clone(),
|
||||||
|
);
|
||||||
|
let pool = Pool::builder()
|
||||||
|
.max_size(POOL_SIZE)
|
||||||
|
.connection_timeout(std::time::Duration::from_millis(5000))
|
||||||
|
.test_on_check_out(false)
|
||||||
|
.build(manager)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
databases.push(pool);
|
||||||
|
}
|
||||||
|
|
||||||
|
ConnectionPool {
|
||||||
|
databases: databases,
|
||||||
addresses: addresses,
|
addresses: addresses,
|
||||||
round_robin: Arc::new(AtomicUsize::new(0)),
|
round_robin: Arc::new(AtomicUsize::new(0)),
|
||||||
banlist: Arc::new(Mutex::new(HashMap::new())),
|
banlist: Arc::new(Mutex::new(HashMap::new())),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Get a connection from the pool. Either round-robin or pick a specific one in case they are sharded.
|
||||||
|
pub async fn get(
|
||||||
|
&self,
|
||||||
|
index: Option<usize>,
|
||||||
|
) -> Result<(PooledConnection<'_, ServerPool>, Address), Error> {
|
||||||
|
match index {
|
||||||
|
// Asking for a specific database, must be sharded.
|
||||||
|
// No failover here.
|
||||||
|
Some(index) => {
|
||||||
|
assert!(index < self.databases.len());
|
||||||
|
match self.databases[index].get().await {
|
||||||
|
Ok(conn) => Ok((conn, self.addresses[index].clone())),
|
||||||
|
Err(err) => {
|
||||||
|
println!(">> Shard {} down: {:?}", index, err);
|
||||||
|
Err(Error::ServerTimeout)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Any database is fine, we're using round-robin here.
|
||||||
|
// Failover included if the server doesn't answer a health check.
|
||||||
|
None => {
|
||||||
|
loop {
|
||||||
|
let index =
|
||||||
|
self.round_robin.fetch_add(1, Ordering::SeqCst) % self.databases.len();
|
||||||
|
let address = self.addresses[index].clone();
|
||||||
|
|
||||||
|
if self.is_banned(&address) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if we can connect
|
||||||
|
let mut conn = match self.databases[index].get().await {
|
||||||
|
Ok(conn) => conn,
|
||||||
|
Err(err) => {
|
||||||
|
println!(">> Banning replica {}, error: {:?}", index, err);
|
||||||
|
self.ban(&address);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Check if this server is alive with a health check
|
||||||
|
let server = &mut *conn;
|
||||||
|
|
||||||
|
match tokio::time::timeout(
|
||||||
|
tokio::time::Duration::from_millis(1000),
|
||||||
|
server.query("SELECT 1"),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(_) => return Ok((conn, address)),
|
||||||
|
Err(_) => {
|
||||||
|
println!(
|
||||||
|
">> Banning replica {} because of failed health check",
|
||||||
|
index
|
||||||
|
);
|
||||||
|
self.ban(&address);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Ban an address (i.e. replica). It no longer will serve
|
/// Ban an address (i.e. replica). It no longer will serve
|
||||||
/// traffic for any new transactions. Existing transactions on that replica
|
/// traffic for any new transactions. Existing transactions on that replica
|
||||||
/// will finish successfully or error out to the clients.
|
/// will finish successfully or error out to the clients.
|
||||||
@@ -150,7 +157,7 @@ impl ReplicaPool {
|
|||||||
let mut guard = self.banlist.lock().unwrap();
|
let mut guard = self.banlist.lock().unwrap();
|
||||||
|
|
||||||
// Everything is banned, nothig is banned
|
// Everything is banned, nothig is banned
|
||||||
if guard.len() == self.addresses.len() {
|
if guard.len() == self.databases.len() {
|
||||||
guard.clear();
|
guard.clear();
|
||||||
drop(guard);
|
drop(guard);
|
||||||
println!(">> Unbanning all replicas.");
|
println!(">> Unbanning all replicas.");
|
||||||
@@ -173,22 +180,58 @@ impl ReplicaPool {
|
|||||||
None => false,
|
None => false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Get a replica to route the query to.
|
pub struct ServerPool {
|
||||||
/// Will attempt to fetch a healthy replica. It will also
|
address: Address,
|
||||||
/// round-robin them for reasonably equal load. Round-robin is done
|
user: User,
|
||||||
/// per transaction.
|
database: String,
|
||||||
pub fn get(&self) -> Address {
|
client_server_map: ClientServerMap,
|
||||||
loop {
|
}
|
||||||
// We'll never hit a 64-bit overflow right....right? :-)
|
|
||||||
let index = self.round_robin.fetch_add(1, Ordering::SeqCst) % self.addresses.len();
|
|
||||||
|
|
||||||
let address = &self.addresses[index];
|
impl ServerPool {
|
||||||
if !self.is_banned(address) {
|
pub fn new(
|
||||||
return address.clone();
|
address: Address,
|
||||||
} else {
|
user: User,
|
||||||
continue;
|
database: &str,
|
||||||
}
|
client_server_map: ClientServerMap,
|
||||||
|
) -> ServerPool {
|
||||||
|
ServerPool {
|
||||||
|
address: address,
|
||||||
|
user: user,
|
||||||
|
database: database.to_string(),
|
||||||
|
client_server_map: client_server_map,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl ManageConnection for ServerPool {
|
||||||
|
type Connection = Server;
|
||||||
|
type Error = Error;
|
||||||
|
|
||||||
|
/// Attempts to create a new connection.
|
||||||
|
async fn connect(&self) -> Result<Self::Connection, Self::Error> {
|
||||||
|
println!(">> Getting new connection from the pool");
|
||||||
|
|
||||||
|
Server::startup(
|
||||||
|
&self.address.host,
|
||||||
|
&self.address.port,
|
||||||
|
&self.user.name,
|
||||||
|
&self.user.password,
|
||||||
|
&self.database,
|
||||||
|
self.client_server_map.clone(),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Determines if the connection is still connected to the database.
|
||||||
|
async fn is_valid(&self, _conn: &mut PooledConnection<'_, Self>) -> Result<(), Self::Error> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Synchronously determine if the connection is no longer usable, if possible.
|
||||||
|
fn has_broken(&self, conn: &mut Self::Connection) -> bool {
|
||||||
|
conn.is_bad()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user