mirror of
https://github.com/postgresml/pgcat.git
synced 2026-03-23 01:16:30 +00:00
#1 Primary/replica selection
This commit is contained in:
20
pgcat.toml
20
pgcat.toml
@@ -43,26 +43,26 @@ password = "sharding_user"
|
||||
# Shard 0
|
||||
[shards.0]
|
||||
|
||||
# [ host, port ]
|
||||
# [ host, port, role ]
|
||||
servers = [
|
||||
[ "127.0.0.1", 5432 ],
|
||||
[ "localhost", 5432 ],
|
||||
[ "127.0.0.1", 5432, "primary" ],
|
||||
[ "localhost", 5432, "replica" ],
|
||||
]
|
||||
# Database name (e.g. "postgres")
|
||||
database = "shard0"
|
||||
|
||||
[shards.1]
|
||||
# [ host, port ]
|
||||
# [ host, port, role ]
|
||||
servers = [
|
||||
[ "127.0.0.1", 5432 ],
|
||||
[ "localhost", 5432 ],
|
||||
[ "127.0.0.1", 5432, "primary" ],
|
||||
[ "localhost", 5432, "replica" ],
|
||||
]
|
||||
database = "shard1"
|
||||
|
||||
[shards.2]
|
||||
# [ host, port ]
|
||||
# [ host, port, role ]
|
||||
servers = [
|
||||
[ "127.0.0.1", 5432 ],
|
||||
[ "localhost", 5432 ],
|
||||
[ "127.0.0.1", 5432, "primary" ],
|
||||
[ "localhost", 5432, "replica" ],
|
||||
]
|
||||
database = "shard2"
|
||||
database = "shard2"
|
||||
|
||||
@@ -7,6 +7,7 @@ use tokio::io::{AsyncReadExt, BufReader};
|
||||
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
use crate::config::Role;
|
||||
use crate::errors::Error;
|
||||
use crate::messages::*;
|
||||
use crate::pool::{ClientServerMap, ConnectionPool};
|
||||
@@ -14,6 +15,7 @@ use crate::server::Server;
|
||||
use crate::sharding::Sharder;
|
||||
|
||||
const SHARDING_REGEX: &str = r"SET SHARDING KEY TO '[0-9]+';";
|
||||
const ROLE_REGEX: &str = r"SET SERVER ROLE TO '(PRIMARY|REPLICA)';";
|
||||
|
||||
/// The client state. One of these is created per client.
|
||||
pub struct Client {
|
||||
@@ -45,6 +47,9 @@ pub struct Client {
|
||||
|
||||
// sharding regex
|
||||
sharding_regex: Regex,
|
||||
|
||||
// role detection regex
|
||||
role_regex: Regex,
|
||||
}
|
||||
|
||||
impl Client {
|
||||
@@ -57,6 +62,7 @@ impl Client {
|
||||
transaction_mode: bool,
|
||||
) -> Result<Client, Error> {
|
||||
let sharding_regex = Regex::new(SHARDING_REGEX).unwrap();
|
||||
let role_regex = Regex::new(ROLE_REGEX).unwrap();
|
||||
|
||||
loop {
|
||||
// Could be StartupMessage or SSLRequest
|
||||
@@ -114,6 +120,7 @@ impl Client {
|
||||
secret_key: secret_key,
|
||||
client_server_map: client_server_map,
|
||||
sharding_regex: sharding_regex,
|
||||
role_regex: role_regex,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -134,6 +141,7 @@ impl Client {
|
||||
secret_key: secret_key,
|
||||
client_server_map: client_server_map,
|
||||
sharding_regex: sharding_regex,
|
||||
role_regex: role_regex,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -172,6 +180,8 @@ impl Client {
|
||||
// - if in transaction mode, this lives for the duration of one transaction.
|
||||
let mut shard: Option<usize> = None;
|
||||
|
||||
let mut role: Option<Role> = None;
|
||||
|
||||
loop {
|
||||
// Read a complete message from the client, which normally would be
|
||||
// either a `Q` (query) or `P` (prepare, extended protocol).
|
||||
@@ -182,18 +192,29 @@ impl Client {
|
||||
|
||||
// Parse for special select shard command.
|
||||
// SET SHARDING KEY TO 'bigint';
|
||||
match self.select_shard(message.clone(), pool.shards()).await {
|
||||
match self.select_shard(message.clone(), pool.shards()) {
|
||||
Some(s) => {
|
||||
set_sharding_key(&mut self.write).await?;
|
||||
custom_protocol_response_ok(&mut self.write, "SET SHARDING KEY").await?;
|
||||
shard = Some(s);
|
||||
continue;
|
||||
}
|
||||
None => (),
|
||||
};
|
||||
|
||||
// Parse for special server role selection command.
|
||||
//
|
||||
match self.select_role(message.clone()) {
|
||||
Some(r) => {
|
||||
custom_protocol_response_ok(&mut self.write, "SET SERVER ROLE").await?;
|
||||
role = Some(r);
|
||||
continue;
|
||||
}
|
||||
None => (),
|
||||
};
|
||||
|
||||
// Grab a server from the pool.
|
||||
// None = any shard
|
||||
let connection = pool.get(shard).await.unwrap();
|
||||
let connection = pool.get(shard, role).await.unwrap();
|
||||
let mut proxy = connection.0;
|
||||
let _address = connection.1;
|
||||
let server = &mut *proxy;
|
||||
@@ -252,6 +273,7 @@ impl Client {
|
||||
// Release server
|
||||
if !server.in_transaction() && self.transaction_mode {
|
||||
shard = None;
|
||||
role = None;
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -311,6 +333,7 @@ impl Client {
|
||||
// Release server
|
||||
if !server.in_transaction() && self.transaction_mode {
|
||||
shard = None;
|
||||
role = None;
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -338,6 +361,7 @@ impl Client {
|
||||
if !server.in_transaction() && self.transaction_mode {
|
||||
println!("Releasing after copy done");
|
||||
shard = None;
|
||||
role = None;
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -361,7 +385,7 @@ impl Client {
|
||||
/// Determine if the query is part of our special syntax, extract
|
||||
/// the shard key, and return the shard to query based on Postgres'
|
||||
/// PARTITION BY HASH function.
|
||||
async fn select_shard(&mut self, mut buf: BytesMut, shards: usize) -> Option<usize> {
|
||||
fn select_shard(&mut self, mut buf: BytesMut, shards: usize) -> Option<usize> {
|
||||
let code = buf.get_u8() as char;
|
||||
|
||||
// Only supporting simpe protocol here, so
|
||||
@@ -390,4 +414,31 @@ impl Client {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
// Pick a primary or a replica from the pool.
|
||||
fn select_role(&mut self, mut buf: BytesMut) -> Option<Role> {
|
||||
let code = buf.get_u8() as char;
|
||||
|
||||
// Same story as select_shard() above.
|
||||
match code {
|
||||
'Q' => (),
|
||||
_ => return None,
|
||||
};
|
||||
|
||||
let len = buf.get_i32();
|
||||
let query = String::from_utf8_lossy(&buf[..len as usize - 4 - 1]).to_ascii_uppercase();
|
||||
|
||||
// Copy / paste from above. If we get one more of these use cases,
|
||||
// it'll be time to abstract :).
|
||||
if self.role_regex.is_match(&query) {
|
||||
let role = query.split("'").collect::<Vec<&str>>()[1];
|
||||
match role {
|
||||
"PRIMARY" => Some(Role::Primary),
|
||||
"REPLICA" => Some(Role::Replica),
|
||||
_ => return None,
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,10 +7,17 @@ use std::collections::HashMap;
|
||||
|
||||
use crate::errors::Error;
|
||||
|
||||
#[derive(Clone, PartialEq, Deserialize, Hash, std::cmp::Eq, Debug, Copy)]
|
||||
pub enum Role {
|
||||
Primary,
|
||||
Replica,
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Hash, std::cmp::Eq, Debug)]
|
||||
pub struct Address {
|
||||
pub host: String,
|
||||
pub port: String,
|
||||
pub role: Role,
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Hash, std::cmp::Eq, Deserialize, Debug)]
|
||||
@@ -32,7 +39,7 @@ pub struct General {
|
||||
|
||||
#[derive(Deserialize, Debug, Clone)]
|
||||
pub struct Shard {
|
||||
pub servers: Vec<(String, u16)>,
|
||||
pub servers: Vec<(String, u16, String)>,
|
||||
pub database: String,
|
||||
}
|
||||
|
||||
@@ -83,5 +90,6 @@ mod test {
|
||||
assert_eq!(config.general.pool_size, 15);
|
||||
assert_eq!(config.shards.len(), 3);
|
||||
assert_eq!(config.shards["1"].servers[0].0, "127.0.0.1");
|
||||
assert_eq!(config.shards["0"].servers[0].2, "primary");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -141,12 +141,16 @@ pub async fn md5_password(
|
||||
Ok(write_all(stream, message).await?)
|
||||
}
|
||||
|
||||
/// Implements a response to our custom `SET SHARDING KEY` command.
|
||||
/// Implements a response to our custom `SET SHARDING KEY`
|
||||
/// and `SET SERVER ROLE` commands.
|
||||
/// This tells the client we're ready for the next query.
|
||||
pub async fn set_sharding_key(stream: &mut OwnedWriteHalf) -> Result<(), Error> {
|
||||
pub async fn custom_protocol_response_ok(
|
||||
stream: &mut OwnedWriteHalf,
|
||||
message: &str,
|
||||
) -> Result<(), Error> {
|
||||
let mut res = BytesMut::with_capacity(25);
|
||||
|
||||
let set_complete = BytesMut::from(&"SET SHARDING KEY\0"[..]);
|
||||
let set_complete = BytesMut::from(&format!("{}\0", message)[..]);
|
||||
let len = (set_complete.len() + 4) as i32;
|
||||
|
||||
// CommandComplete
|
||||
|
||||
28
src/pool.rs
28
src/pool.rs
@@ -3,7 +3,7 @@ use async_trait::async_trait;
|
||||
use bb8::{ManageConnection, Pool, PooledConnection};
|
||||
use chrono::naive::NaiveDateTime;
|
||||
|
||||
use crate::config::{Address, Config, User};
|
||||
use crate::config::{Address, Config, Role, User};
|
||||
use crate::errors::Error;
|
||||
use crate::server::Server;
|
||||
|
||||
@@ -48,9 +48,19 @@ impl ConnectionPool {
|
||||
let mut replica_addresses = Vec::new();
|
||||
|
||||
for server in &shard.servers {
|
||||
let role = match server.2.as_ref() {
|
||||
"primary" => Role::Primary,
|
||||
"replica" => Role::Replica,
|
||||
_ => {
|
||||
println!("> Config error: server role can be 'primary' or 'replica', have: '{}'. Defaulting to 'replica'.", server.2);
|
||||
Role::Replica
|
||||
}
|
||||
};
|
||||
|
||||
let address = Address {
|
||||
host: server.0.clone(),
|
||||
port: server.1.to_string(),
|
||||
role: role,
|
||||
};
|
||||
|
||||
let manager = ServerPool::new(
|
||||
@@ -93,6 +103,7 @@ impl ConnectionPool {
|
||||
pub async fn get(
|
||||
&self,
|
||||
shard: Option<usize>,
|
||||
role: Option<Role>,
|
||||
) -> Result<(PooledConnection<'_, ServerPool>, Address), Error> {
|
||||
// Set this to false to gain ~3-4% speed.
|
||||
let with_health_check = true;
|
||||
@@ -103,6 +114,9 @@ impl ConnectionPool {
|
||||
};
|
||||
|
||||
loop {
|
||||
// TODO: think about making this local, so multiple clients
|
||||
// don't compete for the same round-robin integer.
|
||||
// Especially since we're going to be skipping (see role selection below).
|
||||
let index =
|
||||
self.round_robin.fetch_add(1, Ordering::SeqCst) % self.databases[shard].len();
|
||||
let address = self.addresses[shard][index].clone();
|
||||
@@ -111,6 +125,17 @@ impl ConnectionPool {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Make sure you're getting a primary or a replica
|
||||
// as per request.
|
||||
match role {
|
||||
Some(role) => {
|
||||
if address.role != role {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
None => (),
|
||||
};
|
||||
|
||||
// Check if we can connect
|
||||
// TODO: implement query wait timeout, i.e. time to get a conn from the pool
|
||||
let mut conn = match self.databases[shard][index].get().await {
|
||||
@@ -251,6 +276,7 @@ impl ManageConnection for ServerPool {
|
||||
&self.user.password,
|
||||
&self.database,
|
||||
self.client_server_map.clone(),
|
||||
self.address.role,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ use tokio::io::{AsyncReadExt, BufReader};
|
||||
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
use crate::config::Address;
|
||||
use crate::config::{Address, Role};
|
||||
use crate::errors::Error;
|
||||
use crate::messages::*;
|
||||
use crate::ClientServerMap;
|
||||
@@ -48,6 +48,8 @@ pub struct Server {
|
||||
|
||||
// Mapping of clients and servers used for query cancellation.
|
||||
client_server_map: ClientServerMap,
|
||||
|
||||
role: Role,
|
||||
}
|
||||
|
||||
impl Server {
|
||||
@@ -60,6 +62,7 @@ impl Server {
|
||||
password: &str,
|
||||
database: &str,
|
||||
client_server_map: ClientServerMap,
|
||||
role: Role,
|
||||
) -> Result<Server, Error> {
|
||||
let mut stream = match TcpStream::connect(&format!("{}:{}", host, port)).await {
|
||||
Ok(stream) => stream,
|
||||
@@ -189,6 +192,7 @@ impl Server {
|
||||
data_available: false,
|
||||
bad: false,
|
||||
client_server_map: client_server_map,
|
||||
role: role,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -409,6 +413,7 @@ impl Server {
|
||||
Address {
|
||||
host: self.host.to_string(),
|
||||
port: self.port.to_string(),
|
||||
role: self.role,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user