Update README. Comments. Version bump. (#60)

* update readme

* comments

* just a version bump
This commit is contained in:
Lev Kokotov
2022-03-10 01:33:29 -08:00
committed by GitHub
parent 509e4815a3
commit df85139281
15 changed files with 196 additions and 155 deletions

View File

@@ -1,8 +1,8 @@
/// Admin database.
use bytes::{Buf, BufMut, BytesMut};
use log::{info, trace};
use tokio::net::tcp::OwnedWriteHalf;
use std::collections::HashMap;
use tokio::net::tcp::OwnedWriteHalf;
use crate::config::{get_config, parse};
use crate::errors::Error;
@@ -10,7 +10,7 @@ use crate::messages::*;
use crate::pool::ConnectionPool;
use crate::stats::get_stats;
/// Handle admin client
/// Handle admin client.
pub async fn handle_admin(
stream: &mut OwnedWriteHalf,
mut query: BytesMut,
@@ -58,7 +58,7 @@ pub async fn handle_admin(
}
}
/// SHOW LISTS
/// Column-oriented statistics.
async fn show_lists(stream: &mut OwnedWriteHalf, pool: &ConnectionPool) -> Result<(), Error> {
let stats = get_stats();
@@ -125,7 +125,7 @@ async fn show_lists(stream: &mut OwnedWriteHalf, pool: &ConnectionPool) -> Resul
write_all_half(stream, res).await
}
/// SHOW VERSION
/// Show PgCat version.
async fn show_version(stream: &mut OwnedWriteHalf) -> Result<(), Error> {
let mut res = BytesMut::new();
@@ -140,7 +140,7 @@ async fn show_version(stream: &mut OwnedWriteHalf) -> Result<(), Error> {
write_all_half(stream, res).await
}
/// SHOW POOLS
/// Show utilization of connection pools for each shard and replicas.
async fn show_pools(stream: &mut OwnedWriteHalf, pool: &ConnectionPool) -> Result<(), Error> {
let stats = get_stats();
let config = {
@@ -189,6 +189,7 @@ async fn show_pools(stream: &mut OwnedWriteHalf, pool: &ConnectionPool) -> Resul
res.put(command_complete("SHOW"));
// ReadyForQuery
res.put_u8(b'Z');
res.put_i32(5);
res.put_u8(b'I');
@@ -196,7 +197,7 @@ async fn show_pools(stream: &mut OwnedWriteHalf, pool: &ConnectionPool) -> Resul
write_all_half(stream, res).await
}
/// SHOW DATABASES
/// Show shards and replicas.
async fn show_databases(stream: &mut OwnedWriteHalf, pool: &ConnectionPool) -> Result<(), Error> {
let guard = get_config();
let config = &*guard.clone();
@@ -221,7 +222,6 @@ async fn show_databases(stream: &mut OwnedWriteHalf, pool: &ConnectionPool) -> R
let mut res = BytesMut::new();
// RowDescription
res.put(row_description(&columns));
for shard in 0..pool.shards() {
@@ -265,7 +265,7 @@ async fn ignore_set(stream: &mut OwnedWriteHalf) -> Result<(), Error> {
custom_protocol_response_ok(stream, "SET").await
}
/// RELOAD
/// Reload the configuration file without restarting the process.
async fn reload(stream: &mut OwnedWriteHalf) -> Result<(), Error> {
info!("Reloading config");
@@ -280,7 +280,6 @@ async fn reload(stream: &mut OwnedWriteHalf) -> Result<(), Error> {
let mut res = BytesMut::new();
// CommandComplete
res.put(command_complete("RELOAD"));
// ReadyForQuery
@@ -291,13 +290,14 @@ async fn reload(stream: &mut OwnedWriteHalf) -> Result<(), Error> {
write_all_half(stream, res).await
}
/// Shows current configuration.
async fn show_config(stream: &mut OwnedWriteHalf) -> Result<(), Error> {
let guard = get_config();
let config = &*guard.clone();
let config: HashMap<String, String> = config.into();
drop(guard);
// Configs that cannot be changed dynamically.
// Configs that cannot be changed without restarting.
let immutables = ["host", "port", "connect_timeout"];
// Columns
@@ -327,6 +327,7 @@ async fn show_config(stream: &mut OwnedWriteHalf) -> Result<(), Error> {
res.put(command_complete("SHOW"));
// ReadyForQuery
res.put_u8(b'Z');
res.put_i32(5);
res.put_u8(b'I');
@@ -334,7 +335,7 @@ async fn show_config(stream: &mut OwnedWriteHalf) -> Result<(), Error> {
write_all_half(stream, res).await
}
/// SHOW STATS
/// Show shard and replicas statistics.
async fn show_stats(stream: &mut OwnedWriteHalf, pool: &ConnectionPool) -> Result<(), Error> {
let columns = vec![
("database", DataType::Text),
@@ -378,6 +379,7 @@ async fn show_stats(stream: &mut OwnedWriteHalf, pool: &ConnectionPool) -> Resul
res.put(command_complete("SHOW"));
// ReadyForQuery
res.put_u8(b'Z');
res.put_i32(5);
res.put_u8(b'I');

View File

@@ -1,16 +1,13 @@
/// Implementation of the PostgreSQL client.
/// We are pretending to be the server in this scenario,
/// and this module implements that.
/// Handle clients by pretending to be a PostgreSQL server.
use bytes::{Buf, BufMut, BytesMut};
use log::{debug, error, trace};
use std::collections::HashMap;
use tokio::io::{AsyncReadExt, BufReader};
use tokio::net::{
tcp::{OwnedReadHalf, OwnedWriteHalf},
TcpStream,
};
use std::collections::HashMap;
use crate::admin::handle_admin;
use crate::config::get_config;
use crate::constants::*;
@@ -23,53 +20,52 @@ use crate::stats::Reporter;
/// The client state. One of these is created per client.
pub struct Client {
// The reads are buffered (8K by default).
/// The reads are buffered (8K by default).
read: BufReader<OwnedReadHalf>,
// We buffer the writes ourselves because we know the protocol
// better than a stock buffer.
/// We buffer the writes ourselves because we know the protocol
/// better than a stock buffer.
write: OwnedWriteHalf,
// Internal buffer, where we place messages until we have to flush
// them to the backend.
/// Internal buffer, where we place messages until we have to flush
/// them to the backend.
buffer: BytesMut,
// The client was started with the sole reason to cancel another running query.
/// The client was started with the sole reason to cancel another running query.
cancel_mode: bool,
// In transaction mode, the connection is released after each transaction.
// Session mode has slightly higher throughput per client, but lower capacity.
/// In transaction mode, the connection is released after each transaction.
/// Session mode has slightly higher throughput per client, but lower capacity.
transaction_mode: bool,
// For query cancellation, the client is given a random process ID and secret on startup.
/// For query cancellation, the client is given a random process ID and secret on startup.
process_id: i32,
secret_key: i32,
// Clients are mapped to servers while they use them. This allows a client
// to connect and cancel a query.
/// Clients are mapped to servers while they use them. This allows a client
/// to connect and cancel a query.
client_server_map: ClientServerMap,
// Client parameters, e.g. user, client_encoding, etc.
/// Client parameters, e.g. user, client_encoding, etc.
#[allow(dead_code)]
parameters: HashMap<String, String>,
// Statistics
/// Statistics
stats: Reporter,
// Clients want to talk to admin
/// Clients want to talk to admin database.
admin: bool,
// Last address the client talked to
/// Last address the client talked to.
last_address_id: Option<usize>,
// Last server process id we talked to
/// Last server process id we talked to.
last_server_id: Option<i32>,
}
impl Client {
/// Given a TCP socket, trick the client into thinking we are
/// the Postgres server. Perform the authentication and place
/// the client in query-ready mode.
/// Perform client startup sequence.
/// See docs: <https://www.postgresql.org/docs/12/protocol-flow.html#id-1.10.5.7.3>
pub async fn startup(
mut stream: TcpStream,
client_server_map: ClientServerMap,
@@ -82,14 +78,12 @@ impl Client {
loop {
trace!("Waiting for StartupMessage");
// Could be StartupMessage or SSLRequest
// which makes this variable length.
// Could be StartupMessage, SSLRequest or CancelRequest.
let len = match stream.read_i32().await {
Ok(len) => len,
Err(_) => return Err(Error::ClientBadStartup),
};
// Read whatever is left.
let mut startup = vec![0u8; len as usize - 4];
match stream.read_exact(&mut startup).await {
@@ -189,7 +183,7 @@ impl Client {
}
}
/// Client loop. We handle all messages between the client and the database here.
/// Handle a connected and authenticated client.
pub async fn handle(&mut self, mut pool: ConnectionPool) -> Result<(), Error> {
// The client wants to cancel a query it has issued previously.
if self.cancel_mode {
@@ -225,14 +219,14 @@ impl Client {
// Our custom protocol loop.
// We expect the client to either start a transaction with regular queries
// or issue commands for our sharding and server selection protocols.
// or issue commands for our sharding and server selection protocol.
loop {
trace!("Client idle, waiting for message");
// Read a complete message from the client, which normally would be
// either a `Q` (query) or `P` (prepare, extended protocol).
// We can parse it here before grabbing a server from the pool,
// in case the client is sending some control messages, e.g.
// in case the client is sending some custom protocol messages, e.g.
// SET SHARDING KEY TO 'bigint';
let mut message = read_message(&mut self.read).await?;
@@ -242,43 +236,48 @@ impl Client {
return Ok(());
}
// Handle admin database real quick
// Handle admin database queries.
if self.admin {
trace!("Handling admin command");
handle_admin(&mut self.write, message, pool.clone()).await?;
continue;
}
// Handle all custom protocol commands here.
// Handle all custom protocol commands, if any.
match query_router.try_execute_command(message.clone()) {
// Normal query
// Normal query, not a custom command.
None => {
// Attempt to infer which server we want to query, i.e. primary or replica.
if query_router.query_parser_enabled() && query_router.role() == None {
query_router.infer_role(message.clone());
}
}
// SET SHARD TO
Some((Command::SetShard, _)) => {
custom_protocol_response_ok(&mut self.write, &format!("SET SHARD")).await?;
custom_protocol_response_ok(&mut self.write, "SET SHARD").await?;
continue;
}
// SET SHARDING KEY TO
Some((Command::SetShardingKey, _)) => {
custom_protocol_response_ok(&mut self.write, &format!("SET SHARDING KEY"))
.await?;
custom_protocol_response_ok(&mut self.write, "SET SHARDING KEY").await?;
continue;
}
// SET SERVER ROLE TO
Some((Command::SetServerRole, _)) => {
custom_protocol_response_ok(&mut self.write, "SET SERVER ROLE").await?;
continue;
}
// SHOW SERVER ROLE
Some((Command::ShowServerRole, value)) => {
show_response(&mut self.write, "server role", &value).await?;
continue;
}
// SHOW SHARD
Some((Command::ShowShard, value)) => {
show_response(&mut self.write, "shard", &value).await?;
continue;
@@ -290,7 +289,7 @@ impl Client {
error_response(
&mut self.write,
&format!(
"shard '{}' is more than configured '{}'",
"shard {} is more than configured {}",
query_router.shard(),
pool.shards()
),
@@ -301,7 +300,7 @@ impl Client {
debug!("Waiting for connection from pool");
// Grab a server from the pool: the client issued a regular query.
// Grab a server from the pool.
let connection = match pool
.get(query_router.shard(), query_router.role(), self.process_id)
.await
@@ -322,18 +321,18 @@ impl Client {
let address = connection.1;
let server = &mut *reference;
// Claim this server as mine for query cancellation.
// Server is assigned to the client in case the client wants to
// cancel a query later.
server.claim(self.process_id, self.secret_key);
// "disconnect" from the previous server stats-wise
// Update statistics.
if let Some(last_address_id) = self.last_address_id {
self.stats
.client_disconnecting(self.process_id, last_address_id);
}
// Client active & server active
self.stats.client_active(self.process_id, address.id);
self.stats.server_active(server.process_id(), address.id);
self.last_address_id = Some(address.id);
self.last_server_id = Some(server.process_id());
@@ -346,6 +345,9 @@ impl Client {
// Transaction loop. Multiple queries can be issued by the client here.
// The connection belongs to the client until the transaction is over,
// or until the client disconnects if we are in session mode.
//
// If the client is in session mode, no more custom protocol
// commands will be accepted.
loop {
let mut message = if message.len() == 0 {
trace!("Waiting for message inside transaction or in session mode");
@@ -353,10 +355,10 @@ impl Client {
match read_message(&mut self.read).await {
Ok(message) => message,
Err(err) => {
// Client disconnected without warning.
// Client disconnected inside a transaction.
// Clean up the server and re-use it.
// This prevents connection thrashing by bad clients.
if server.in_transaction() {
// Client left dirty server. Clean up and proceed
// without thrashing this connection.
server.query("ROLLBACK; DISCARD ALL;").await?;
}
@@ -383,13 +385,11 @@ impl Client {
'Q' => {
debug!("Sending query to server");
// TODO: implement retries here for read-only transactions.
server.send(original).await?;
// Read all data the server has to offer, which can be multiple messages
// buffered in 8196 bytes chunks.
loop {
// TODO: implement retries here for read-only transactions.
let response = server.recv().await?;
// Send server reply to the client.
@@ -409,7 +409,6 @@ impl Client {
// Report query executed statistics.
self.stats.query(self.process_id, address.id);
// The transaction is over, we can release the connection back to the pool.
if !server.in_transaction() {
// Report transaction executed statistics.
self.stats.transaction(self.process_id, address.id);
@@ -429,7 +428,6 @@ impl Client {
// connection before releasing into the pool.
// Pgbouncer closes the connection which leads to
// connection thrashing when clients misbehave.
// This pool will protect the database. :salute:
if server.in_transaction() {
server.query("ROLLBACK; DISCARD ALL;").await?;
}
@@ -468,7 +466,6 @@ impl Client {
self.buffer.put(&original[..]);
// TODO: retries for read-only transactions.
server.send(self.buffer.clone()).await?;
self.buffer.clear();
@@ -476,7 +473,6 @@ impl Client {
// Read all data the server has to offer, which can be multiple messages
// buffered in 8196 bytes chunks.
loop {
// TODO: retries for read-only transactions
let response = server.recv().await?;
match write_all_half(&mut self.write, response).await {
@@ -495,11 +491,11 @@ impl Client {
// Report query executed statistics.
self.stats.query(self.process_id, address.id);
// Release server back to the pool if we are in transaction mode.
// If we are in session mode, we keep the server until the client disconnects.
if !server.in_transaction() {
self.stats.transaction(self.process_id, address.id);
// Release server back to the pool if we are in transaction mode.
// If we are in session mode, we keep the server until the client disconnects.
if self.transaction_mode {
self.stats.server_idle(server.process_id(), address.id);
break;
@@ -529,11 +525,11 @@ impl Client {
}
};
// Release server back to the pool if we are in transaction mode.
// If we are in session mode, we keep the server until the client disconnects.
if !server.in_transaction() {
self.stats.transaction(self.process_id, address.id);
// Release server back to the pool if we are in transaction mode.
// If we are in session mode, we keep the server until the client disconnects.
if self.transaction_mode {
self.stats.server_idle(server.process_id(), address.id);
break;
@@ -556,7 +552,7 @@ impl Client {
}
}
/// Release the server from being mine. I can't cancel its queries anymore.
/// Release the server from the client: it can't cancel its queries anymore.
pub fn release(&self) {
let mut guard = self.client_server_map.lock();
guard.remove(&(self.process_id, self.secret_key));
@@ -565,11 +561,10 @@ impl Client {
impl Drop for Client {
fn drop(&mut self) {
// Disconnect the client
// Update statistics.
if let Some(address_id) = self.last_address_id {
self.stats.client_disconnecting(self.process_id, address_id);
// The server is now idle
if let Some(process_id) = self.last_server_id {
self.stats.server_idle(process_id, address_id);
}

View File

@@ -1,18 +1,20 @@
/// Parse the configuration file.
use arc_swap::{ArcSwap, Guard};
use log::{error, info};
use once_cell::sync::Lazy;
use serde_derive::Deserialize;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tokio::fs::File;
use tokio::io::AsyncReadExt;
use toml;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use crate::errors::Error;
/// Globally available configuration.
static CONFIG: Lazy<ArcSwap<Config>> = Lazy::new(|| ArcSwap::from_pointee(Config::default()));
/// Server role: primary or replica.
#[derive(Clone, PartialEq, Deserialize, Hash, std::cmp::Eq, Debug, Copy)]
pub enum Role {
Primary,
@@ -46,6 +48,7 @@ impl PartialEq<Role> for Option<Role> {
}
}
/// Address identifying a PostgreSQL server uniquely.
#[derive(Clone, PartialEq, Hash, std::cmp::Eq, Debug)]
pub struct Address {
pub id: usize,
@@ -70,6 +73,7 @@ impl Default for Address {
}
impl Address {
/// Address name (aka database) used in `SHOW STATS`, `SHOW DATABASES`, and `SHOW POOLS`.
pub fn name(&self) -> String {
match self.role {
Role::Primary => format!("shard_{}_primary", self.shard),
@@ -79,6 +83,7 @@ impl Address {
}
}
/// PostgreSQL user.
#[derive(Clone, PartialEq, Hash, std::cmp::Eq, Deserialize, Debug)]
pub struct User {
pub name: String,
@@ -94,6 +99,7 @@ impl Default for User {
}
}
/// General configuration.
#[derive(Deserialize, Debug, Clone)]
pub struct General {
pub host: String,
@@ -119,6 +125,7 @@ impl Default for General {
}
}
/// Shard configuration.
#[derive(Deserialize, Debug, Clone)]
pub struct Shard {
pub servers: Vec<(String, u16, String)>,
@@ -134,6 +141,7 @@ impl Default for Shard {
}
}
/// Query Router configuration.
#[derive(Deserialize, Debug, Clone)]
pub struct QueryRouter {
pub default_role: String,
@@ -153,6 +161,7 @@ impl Default for QueryRouter {
}
}
/// Configuration wrapper.
#[derive(Deserialize, Debug, Clone)]
pub struct Config {
pub path: Option<String>,
@@ -217,6 +226,7 @@ impl From<&Config> for std::collections::HashMap<String, String> {
}
impl Config {
/// Print current configuration.
pub fn show(&self) {
info!("Pool size: {}", self.general.pool_size);
info!("Pool mode: {}", self.general.pool_mode);
@@ -231,11 +241,14 @@ impl Config {
}
}
/// Get a read-only instance of the configuration
/// from anywhere in the app.
/// ArcSwap makes this cheap and quick.
pub fn get_config() -> Guard<Arc<Config>> {
CONFIG.load()
}
/// Parse the config.
/// Parse the configuration file located at the path.
pub async fn parse(path: &str) -> Result<(), Error> {
let mut contents = String::new();
let mut file = match File::open(path).await {
@@ -346,6 +359,7 @@ pub async fn parse(path: &str) -> Result<(), Error> {
config.path = Some(path.to_string());
// Update the configuration globally.
CONFIG.store(Arc::new(config.clone()));
Ok(())

View File

@@ -1,7 +1,6 @@
/// Various protocol constants, as defined in
/// https://www.postgresql.org/docs/12/protocol-message-formats.html
/// <https://www.postgresql.org/docs/12/protocol-message-formats.html>
/// and elsewhere in the source code.
/// Also other constants we use elsewhere.
// Used in the StartupMessage to indicate regular handshake.
// 196608 == 3 << 16, i.e. protocol version 3.0 (major 3, minor 0).
pub const PROTOCOL_VERSION_NUMBER: i32 = 196608;

View File

@@ -1,12 +1,12 @@
/// Errors.
/// Various errors.
#[derive(Debug, PartialEq)]
pub enum Error {
    /// A read or write on a client or server socket failed.
    SocketError,
    // ClientDisconnected,
    /// The client sent a malformed or unreadable startup packet.
    ClientBadStartup,
    /// Messages arrived out of the order the protocol expects.
    /// NOTE(review): inferred from the name — confirm at usage sites.
    ProtocolSyncError,
    /// The backend server reported a failure.
    /// NOTE(review): presumably an ErrorResponse from Postgres — verify.
    ServerError,
    // ServerTimeout,
    // DirtyServer,
    /// The configuration file could not be parsed or failed validation.
    BadConfig,
    /// No healthy server could be obtained from the pool,
    /// e.g. all candidates banned or failing health checks.
    AllServersDown,
}

View File

@@ -58,19 +58,15 @@ mod server;
mod sharding;
mod stats;
// Support for query cancellation: this maps our process_ids and
// secret keys to the backend's.
use config::get_config;
use pool::{ClientServerMap, ConnectionPool};
use stats::{Collector, Reporter};
/// Main!
#[tokio::main(worker_threads = 4)]
async fn main() {
env_logger::init();
info!("Welcome to PgCat! Meow.");
// Prepare regexes
if !query_router::QueryRouter::setup() {
error!("Could not setup query router");
return;
@@ -84,7 +80,6 @@ async fn main() {
String::from("pgcat.toml")
};
// Prepare the config
match config::parse(&config_file).await {
Ok(_) => (),
Err(err) => {
@@ -94,8 +89,8 @@ async fn main() {
};
let config = get_config();
let addr = format!("{}:{}", config.general.host, config.general.port);
let listener = match TcpListener::bind(&addr).await {
Ok(sock) => sock,
Err(err) => {
@@ -105,18 +100,20 @@ async fn main() {
};
info!("Running on {}", addr);
config.show();
// Tracks which client is connected to which server for query cancellation.
let client_server_map: ClientServerMap = Arc::new(Mutex::new(HashMap::new()));
// Collect statistics and send them to StatsD
// Statistics reporting.
let (tx, rx) = mpsc::channel(100);
// Connection pool for all shards and replicas
// Connection pool that allows to query all shards and replicas.
let mut pool =
ConnectionPool::from_config(client_server_map.clone(), Reporter::new(tx.clone())).await;
// Statistics collector task.
let collector_tx = tx.clone();
let addresses = pool.databases();
tokio::task::spawn(async move {
@@ -135,7 +132,7 @@ async fn main() {
info!("Waiting for clients");
// Main app runs here.
// Client connection loop.
tokio::task::spawn(async move {
loop {
let pool = pool.clone();
@@ -151,7 +148,7 @@ async fn main() {
}
};
// Client goes to another thread, bye.
// Handle client.
tokio::task::spawn(async move {
let start = chrono::offset::Utc::now().naive_utc();
match client::Client::startup(socket, client_server_map, server_info, reporter)
@@ -185,7 +182,7 @@ async fn main() {
}
});
// Reload config
// Reload config:
// kill -SIGHUP $(pgrep pgcat)
tokio::task::spawn(async move {
let mut stream = unix_signal(SignalKind::hangup()).unwrap();
@@ -205,6 +202,7 @@ async fn main() {
}
});
// Exit on Ctrl-C (SIGINT) and SIGTERM.
let mut term_signal = unix_signal(SignalKind::terminate()).unwrap();
tokio::select! {

View File

@@ -222,7 +222,7 @@ pub async fn custom_protocol_response_ok(
/// Send a custom error message to the client.
/// Tell the client we are ready for the next query and no rollback is necessary.
/// Docs on error codes: https://www.postgresql.org/docs/12/errcodes-appendix.html
/// Docs on error codes: <https://www.postgresql.org/docs/12/errcodes-appendix.html>.
pub async fn error_response(stream: &mut OwnedWriteHalf, message: &str) -> Result<(), Error> {
let mut error = BytesMut::new();
@@ -339,6 +339,7 @@ pub fn row_description(columns: &Vec<(&str, DataType)>) -> BytesMut {
res
}
/// Create a DataRow message.
pub fn data_row(row: &Vec<String>) -> BytesMut {
let mut res = BytesMut::new();
let mut data_row = BytesMut::new();
@@ -358,6 +359,7 @@ pub fn data_row(row: &Vec<String>) -> BytesMut {
res
}
/// Create a CommandComplete message.
pub fn command_complete(command: &str) -> BytesMut {
let cmd = BytesMut::from(format!("{}\0", command).as_bytes());
let mut res = BytesMut::new();

View File

@@ -1,24 +1,23 @@
/// Pooling and failover and banlist.
/// Pooling, failover and banlist.
use async_trait::async_trait;
use bb8::{ManageConnection, Pool, PooledConnection};
use bytes::BytesMut;
use chrono::naive::NaiveDateTime;
use log::{debug, error, info, warn};
use parking_lot::{Mutex, RwLock};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use crate::config::{get_config, Address, Role, User};
use crate::errors::Error;
use crate::server::Server;
use crate::stats::Reporter;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
// Banlist: bad servers go in here.
pub type BanList = Arc<RwLock<Vec<HashMap<Address, NaiveDateTime>>>>;
pub type ClientServerMap = Arc<Mutex<HashMap<(i32, i32), (i32, i32, String, String)>>>;
/// The globally accessible connection pool.
#[derive(Clone, Debug)]
pub struct ConnectionPool {
databases: Vec<Vec<Pool<ServerPool>>>,
@@ -29,7 +28,7 @@ pub struct ConnectionPool {
}
impl ConnectionPool {
/// Construct the connection pool from a config file.
/// Construct the connection pool from the configuration.
pub async fn from_config(
client_server_map: ClientServerMap,
stats: Reporter,
@@ -204,10 +203,9 @@ impl ConnectionPool {
}
while allowed_attempts > 0 {
// Round-robin each client's queries.
// If a client only sends one query and then disconnects, it doesn't matter
// which replica it'll go to.
// Round-robin replicas.
self.round_robin += 1;
let index = self.round_robin % addresses.len();
let address = &addresses[index];
@@ -239,7 +237,7 @@ impl ConnectionPool {
}
};
// // Check if this server is alive with a health check
// // Check if this server is alive with a health check.
let server = &mut *conn;
let healthcheck_timeout = get_config().general.healthcheck_timeout;
@@ -251,7 +249,7 @@ impl ConnectionPool {
)
.await
{
// Check if health check succeeded
// Check if health check succeeded.
Ok(res) => match res {
Ok(_) => {
self.stats
@@ -259,8 +257,11 @@ impl ConnectionPool {
self.stats.server_idle(conn.process_id(), address.id);
return Ok((conn, address.clone()));
}
// Health check failed.
Err(_) => {
error!("Banning replica {} because of failed health check", index);
// Don't leave a bad connection in the pool.
server.mark_bad();
@@ -271,7 +272,8 @@ impl ConnectionPool {
continue;
}
},
// Health check never came back, database is really really down
// Health check timed out.
Err(_) => {
error!("Banning replica {} because of health check timeout", index);
// Don't leave a bad connection in the pool.
@@ -358,14 +360,18 @@ impl ConnectionPool {
}
}
/// Get the number of configured shards.
///
/// `databases` holds one inner vector of pools per shard,
/// so the outer vector's length is the shard count.
pub fn shards(&self) -> usize {
    self.databases.len()
}
/// Get the number of servers (primary and replicas)
/// configured for a shard.
///
/// Indexes `addresses` by shard, so `shard` must be less than
/// `self.shards()` or this panics.
pub fn servers(&self, shard: usize) -> usize {
    self.addresses[shard].len()
}
/// Get the total number of servers (databases) we are connected to.
pub fn databases(&self) -> usize {
let mut databases = 0;
for shard in 0..self.shards() {
@@ -374,15 +380,18 @@ impl ConnectionPool {
databases
}
/// Get pool state for a particular shard server as reported by bb8
/// (e.g. number of idle and total connections).
///
/// Panics if `shard` or `server` are out of range.
pub fn pool_state(&self, shard: usize, server: usize) -> bb8::State {
    self.databases[shard][server].state()
}
/// Get the address information for a shard server.
///
/// Panics if `shard` or `server` are out of range.
pub fn address(&self, shard: usize, server: usize) -> &Address {
    &self.addresses[shard][server]
}
}
/// Wrapper for the bb8 connection pool.
pub struct ServerPool {
address: Address,
user: User,
@@ -427,6 +436,7 @@ impl ManageConnection for ServerPool {
let process_id = rand::random::<i32>();
self.stats.server_login(process_id, self.address.id);
// Connect to the PostgreSQL server.
match Server::startup(
&self.address,
&self.user,

View File

@@ -1,5 +1,3 @@
use crate::config::{get_config, Role};
use crate::sharding::{Sharder, ShardingFunction};
/// Route queries automatically based on explicitly requested
/// or implied query characteristics.
use bytes::{Buf, BytesMut};
@@ -10,6 +8,10 @@ use sqlparser::ast::Statement::{Query, StartTransaction};
use sqlparser::dialect::PostgreSqlDialect;
use sqlparser::parser::Parser;
use crate::config::{get_config, Role};
use crate::sharding::{Sharder, ShardingFunction};
/// Regexes used to parse custom commands.
const CUSTOM_SQL_REGEXES: [&str; 5] = [
r"(?i)^ *SET SHARDING KEY TO '?([0-9]+)'? *;? *$",
r"(?i)^ *SET SHARD TO '?([0-9]+|ANY)'? *;? *$",
@@ -18,6 +20,7 @@ const CUSTOM_SQL_REGEXES: [&str; 5] = [
r"(?i)^ *SHOW SERVER ROLE *;? *$",
];
/// Custom commands.
#[derive(PartialEq, Debug)]
pub enum Command {
SetShardingKey,
@@ -27,37 +30,39 @@ pub enum Command {
ShowServerRole,
}
// Quick test
/// Quickly test for match when a query is received.
static CUSTOM_SQL_REGEX_SET: OnceCell<RegexSet> = OnceCell::new();
// Capture value
// Get the value inside the custom command.
static CUSTOM_SQL_REGEX_LIST: OnceCell<Vec<Regex>> = OnceCell::new();
/// The query router.
pub struct QueryRouter {
// By default, queries go here, unless we have better information
// about what the client wants.
/// By default, queries go here, unless we have better information
/// about what the client wants.
default_server_role: Option<Role>,
// Number of shards in the cluster.
/// Number of shards in the cluster.
shards: usize,
// Which shard we should be talking to right now.
/// Which shard we should be talking to right now.
active_shard: Option<usize>,
// Should we be talking to a primary or a replica?
/// Which server should we be talking to.
active_role: Option<Role>,
// Include the primary into the replica pool?
/// Include the primary into the replica pool for reads.
primary_reads_enabled: bool,
// Should we try to parse queries?
/// Should we try to parse queries to route them to replicas or primary automatically.
query_parser_enabled: bool,
// Which sharding function are we using?
/// Which sharding function we're using.
sharding_function: ShardingFunction,
}
impl QueryRouter {
/// One-time initialization of regexes.
pub fn setup() -> bool {
let set = match RegexSet::new(&CUSTOM_SQL_REGEXES) {
Ok(rgx) => rgx,
@@ -88,6 +93,7 @@ impl QueryRouter {
}
}
/// Create a new instance of the query router. Each client gets its own.
pub fn new() -> QueryRouter {
let config = get_config();
@@ -120,6 +126,7 @@ impl QueryRouter {
pub fn try_execute_command(&mut self, mut buf: BytesMut) -> Option<(Command, String)> {
let code = buf.get_u8() as char;
// Only simple protocol supported for commands.
if code != 'Q' {
return None;
}
@@ -158,8 +165,7 @@ impl QueryRouter {
// figured out a better way just yet. I think I can write a single Regex
// that matches all 5 custom SQL patterns, but maybe that's not very legible?
//
// I think this is faster than running the Regex engine 5 times, so
// this is a strong maybe for me so far.
// I think this is faster than running the Regex engine 5 times.
match regex_list[matches[0]].captures(&query) {
Some(captures) => match captures.get(1) {
Some(value) => value.as_str().to_string(),
@@ -221,7 +227,6 @@ impl QueryRouter {
}
"default" => {
// TODO: reset query parser to default here.
self.active_role = self.default_server_role;
self.query_parser_enabled = get_config().query_router.query_parser_enabled;
self.active_role
@@ -243,12 +248,14 @@ impl QueryRouter {
let len = buf.get_i32() as usize;
let query = match code {
// Query
'Q' => {
let query = String::from_utf8_lossy(&buf[..len - 5]).to_string();
debug!("Query: '{}'", query);
query
}
// Parse (prepared statement)
'P' => {
let mut start = 0;
let mut end;
@@ -271,6 +278,7 @@ impl QueryRouter {
query.replace("$", "") // Remove placeholders turning them into "values"
}
_ => return false,
};
@@ -334,6 +342,7 @@ impl QueryRouter {
self.query_parser_enabled
}
/// Allows to toggle primary reads in tests.
#[allow(dead_code)]
pub fn toggle_primary_reads(&mut self, value: bool) {
self.primary_reads_enabled = value;

View File

@@ -1,6 +1,6 @@
/// Implementation of the PostgreSQL server (database) protocol.
/// Here we are pretending to the a Postgres client.
use bytes::{Buf, BufMut, BytesMut};
//! Implementation of the PostgreSQL server (database) protocol.
//! Here we are pretending to be a Postgres client.
use log::{debug, error, info, trace};
use tokio::io::{AsyncReadExt, BufReader};
use tokio::net::{
@@ -17,42 +17,42 @@ use crate::ClientServerMap;
/// Server state.
pub struct Server {
// Server host, e.g. localhost,
// port, e.g. 5432, and role, e.g. primary or replica.
/// Server host, e.g. localhost,
/// port, e.g. 5432, and role, e.g. primary or replica.
address: Address,
// Buffered read socket.
/// Buffered read socket.
read: BufReader<OwnedReadHalf>,
// Unbuffered write socket (our client code buffers).
/// Unbuffered write socket (our client code buffers).
write: OwnedWriteHalf,
// Our server response buffer. We buffer data before we give it to the client.
/// Our server response buffer. We buffer data before we give it to the client.
buffer: BytesMut,
// Server information the server sent us over on startup.
/// Server information the server sent us over on startup.
server_info: BytesMut,
// Backend id and secret key used for query cancellation.
/// Backend id and secret key used for query cancellation.
process_id: i32,
secret_key: i32,
// Is the server inside a transaction or idle.
/// Is the server inside a transaction or idle.
in_transaction: bool,
// Is there more data for the client to read.
/// Is there more data for the client to read.
data_available: bool,
// Is the server broken? We'll remove it from the pool if so.
/// Is the server broken? We'll remove it from the pool if so.
bad: bool,
// Mapping of clients and servers used for query cancellation.
/// Mapping of clients and servers used for query cancellation.
client_server_map: ClientServerMap,
// Server connected at.
/// Server connected at.
connected_at: chrono::naive::NaiveDateTime,
// Reports various metrics, e.g. data sent & received.
/// Reports various metrics, e.g. data sent & received.
stats: Reporter,
}
@@ -77,7 +77,7 @@ impl Server {
trace!("Sending StartupMessage");
// Send the startup packet telling the server we're a normal Postgres client.
// StartupMessage
startup(&mut stream, &user.name, database).await?;
let mut server_info = BytesMut::new();
@@ -187,7 +187,7 @@ impl Server {
// BackendKeyData
'K' => {
// The frontend must save these values if it wishes to be able to issue CancelRequest messages later.
// See: https://www.postgresql.org/docs/12/protocol-message-formats.html
// See: <https://www.postgresql.org/docs/12/protocol-message-formats.html>.
process_id = match stream.read_i32().await {
Ok(id) => id,
Err(_) => return Err(Error::SocketError),
@@ -208,8 +208,6 @@ impl Server {
Err(_) => return Err(Error::SocketError),
};
// This is the last step in the client-server connection setup,
// and indicates the server is ready for to query it.
let (read, write) = stream.into_split();
return Ok(Server {
@@ -342,8 +340,7 @@ impl Server {
// More data is available after this message, this is not the end of the reply.
self.data_available = true;
// Don't flush yet, the more we buffer, the faster this goes...
// up to a limit of course.
// Don't flush yet, the more we buffer, the faster this goes...up to a limit.
if self.buffer.len() >= 8196 {
break;
}
@@ -411,7 +408,7 @@ impl Server {
/// Indicate that this server connection cannot be re-used and must be discarded.
pub fn mark_bad(&mut self) {
error!("Server marked bad");
error!("Server {:?} marked bad", self.address);
self.bad = true;
}
@@ -462,6 +459,7 @@ impl Server {
self.address.clone()
}
/// Get the server's unique identifier.
pub fn process_id(&self) -> i32 {
self.process_id
}
@@ -481,9 +479,10 @@ impl Drop for Server {
match self.write.try_write(&bytes) {
Ok(_) => (),
Err(_) => (),
Err(_) => debug!("Dirty shutdown"),
};
// Should not matter.
self.bad = true;
let now = chrono::offset::Utc::now().naive_utc();

View File

@@ -1,20 +1,27 @@
/// Implements various sharding functions.
use sha1::{Digest, Sha1};
// https://github.com/postgres/postgres/blob/27b77ecf9f4d5be211900eda54d8155ada50d696/src/include/catalog/partition.h#L20
/// See: <https://github.com/postgres/postgres/blob/27b77ecf9f4d5be211900eda54d8155ada50d696/src/include/catalog/partition.h#L20>.
const PARTITION_HASH_SEED: u64 = 0x7A5B22367996DCFD;
/// The sharding functions we support.
#[derive(Debug, PartialEq, Copy, Clone)]
pub enum ShardingFunction {
    /// Postgres' own hash function for BIGINT columns, as used by
    /// `PARTITION BY HASH` — shard placement matches a hash-partitioned table.
    PgBigintHash,
    /// SHA-1-based hashing of the sharding key
    /// (presumably via the `sha1` crate imported above — implementation not visible here).
    Sha1,
}
/// The sharder: maps a sharding key to a shard number
/// using the configured sharding function.
pub struct Sharder {
    /// Number of shards in the cluster.
    shards: usize,
    /// The sharding function in use.
    sharding_function: ShardingFunction,
}
impl Sharder {
/// Create new instance of the sharder.
pub fn new(shards: usize, sharding_function: ShardingFunction) -> Sharder {
Sharder {
shards,
@@ -22,6 +29,7 @@ impl Sharder {
}
}
/// Compute the shard given sharding key.
pub fn shard(&self, key: i64) -> usize {
match self.sharding_function {
ShardingFunction::PgBigintHash => self.pg_bigint_hash(key),
@@ -31,7 +39,7 @@ impl Sharder {
/// Hash function used by Postgres to determine which partition
/// to put the row in when using HASH(column) partitioning.
/// Source: https://github.com/postgres/postgres/blob/27b77ecf9f4d5be211900eda54d8155ada50d696/src/common/hashfn.c#L631
/// Source: <https://github.com/postgres/postgres/blob/27b77ecf9f4d5be211900eda54d8155ada50d696/src/common/hashfn.c#L631>.
/// Supports only 1 bigint at the moment, but we can add more later.
fn pg_bigint_hash(&self, key: i64) -> usize {
let mut lohalf = key as u32;
@@ -119,6 +127,7 @@ impl Sharder {
a
}
#[inline]
fn pg_u32_hash(k: u32) -> u64 {
let mut a: u32 = 0x9e3779b9 as u32 + std::mem::size_of::<u32>() as u32 + 3923095 as u32;
let mut b = a;

View File

@@ -5,12 +5,12 @@ use parking_lot::Mutex;
use std::collections::HashMap;
use tokio::sync::mpsc::{Receiver, Sender};
// Latest stats updated every second; used in SHOW STATS and other admin commands.
/// Latest stats updated every second; used in SHOW STATS and other admin commands.
static LATEST_STATS: Lazy<Mutex<HashMap<usize, HashMap<String, i64>>>> =
Lazy::new(|| Mutex::new(HashMap::new()));
// Statistics period used for average calculations.
// 15 seconds.
/// Statistics period used for average calculations.
/// 15 seconds.
static STAT_PERIOD: u64 = 15000;
/// The names for the events reported