Compare commits

..

2 Commits

Author SHA1 Message Date
Lev Kokotov
bca5318d5e Fix broken custom config 2022-09-12 15:58:11 -04:00
Lev Kokotov
efd6b2edae Automatic shard detection 2022-09-12 15:07:10 -04:00
5 changed files with 232 additions and 203 deletions

View File

@@ -82,6 +82,7 @@ primary_reads_enabled = true
# sha1: A hashing function based on SHA1
#
sharding_function = "pg_bigint_hash"
sharding_key = "id"
# Credentials for users that may connect to this cluster
[pools.sharded_db.users.0]

View File

@@ -88,9 +88,6 @@ pub struct Client<S, T> {
/// Used to notify clients about an impending shutdown
shutdown: Receiver<()>,
// Sharding key column position
sharding_key_column: Option<usize>,
}
/// Client entrypoint.
@@ -508,7 +505,6 @@ where
application_name: application_name.to_string(),
shutdown,
connected_to_server: false,
sharding_key_column: None,
});
}
@@ -543,7 +539,6 @@ where
application_name: String::from("undefined"),
shutdown,
connected_to_server: false,
sharding_key_column: None,
});
}
@@ -667,7 +662,7 @@ where
// Normal query, not a custom command.
None => {
if query_router.query_parser_enabled() {
query_router.infer_role(message.clone());
query_router.infer_role_and_shard(message.clone());
}
}
@@ -729,13 +724,6 @@ where
show_response(&mut self.write, "primary reads", &value).await?;
continue;
}
// COPY .. SHARDING_KEY_COLUMN ..
Some((Command::StartShardedCopy, value)) => {
custom_protocol_response_ok(&mut self.write, "SHARDED_COPY").await?;
self.sharding_key_column = Some(value.parse::<usize>().unwrap());
continue;
}
};
debug!("Waiting for connection from pool");
@@ -804,7 +792,7 @@ where
// If the client is in session mode, no more custom protocol
// commands will be accepted.
loop {
let message = if message.len() == 0 {
let mut message = if message.len() == 0 {
trace!("Waiting for message inside transaction or in session mode");
match read_message(&mut self.read).await {
@@ -823,11 +811,154 @@ where
msg
};
match self.handle_message(&pool, server, &address, message).await? {
Some(done) => if done { break; },
None => return Ok(()),
};
// The message will be forwarded to the server intact. We still would like to
// parse it below to figure out what to do with it.
let original = message.clone();
let code = message.get_u8() as char;
let _len = message.get_i32() as usize;
trace!("Message: {}", code);
match code {
// ReadyForQuery
'Q' => {
debug!("Sending query to server");
self.send_and_receive_loop(code, original, server, &address, &pool)
.await?;
if !server.in_transaction() {
// Report transaction executed statistics.
self.stats.transaction(self.process_id, address.id);
// Release server back to the pool if we are in transaction mode.
// If we are in session mode, we keep the server until the client disconnects.
if self.transaction_mode {
break;
}
}
}
// Terminate
'X' => {
server.checkin_cleanup().await?;
self.release();
return Ok(());
}
// Parse
// The query with placeholders is here, e.g. `SELECT * FROM users WHERE email = $1 AND active = $2`.
'P' => {
self.buffer.put(&original[..]);
}
// Bind
// The placeholder's replacements are here, e.g. 'user@email.com' and 'true'
'B' => {
self.buffer.put(&original[..]);
}
// Describe
// Command a client can issue to describe a previously prepared named statement.
'D' => {
self.buffer.put(&original[..]);
}
// Execute
// Execute a prepared statement prepared in `P` and bound in `B`.
'E' => {
self.buffer.put(&original[..]);
}
// Sync
// Frontend (client) is asking for the query result now.
'S' => {
debug!("Sending query to server");
self.buffer.put(&original[..]);
// Clone after freeze does not allocate
let first_message_code = (*self.buffer.get(0).unwrap_or(&0)) as char;
// Almost certainly true
if first_message_code == 'P' {
// Message layout
// P followed by 32 int followed by null-terminated statement name
// So message code should be in offset 0 of the buffer, first character
// in prepared statement name would be index 5
let first_char_in_name = *self.buffer.get(5).unwrap_or(&0);
if first_char_in_name != 0 {
// This is a named prepared statement
// Server connection state will need to be cleared at checkin
server.mark_dirty();
}
}
self.send_and_receive_loop(
code,
self.buffer.clone(),
server,
&address,
&pool,
)
.await?;
self.buffer.clear();
if !server.in_transaction() {
self.stats.transaction(self.process_id, address.id);
// Release server back to the pool if we are in transaction mode.
// If we are in session mode, we keep the server until the client disconnects.
if self.transaction_mode {
break;
}
}
}
// CopyData
'd' => {
// Forward the data to the server,
// don't buffer it since it can be rather large.
self.send_server_message(server, original, &address, &pool)
.await?;
}
// CopyDone or CopyFail
// Copy is done, successfully or not.
'c' | 'f' => {
self.send_server_message(server, original, &address, &pool)
.await?;
let response = self.receive_server_message(server, &address, &pool).await?;
match write_all_half(&mut self.write, response).await {
Ok(_) => (),
Err(err) => {
server.mark_bad();
return Err(err);
}
};
if !server.in_transaction() {
self.stats.transaction(self.process_id, address.id);
// Release server back to the pool if we are in transaction mode.
// If we are in session mode, we keep the server until the client disconnects.
if self.transaction_mode {
break;
}
}
}
// Some unexpected message. We either did not implement the protocol correctly
// or this is not a Postgres client we're talking to.
_ => {
error!("Unexpected code: {}", code);
}
}
}
// The server is no longer bound to us, we can't cancel it's queries anymore.
@@ -841,156 +972,6 @@ where
}
}
async fn handle_message(&mut self, pool: &ConnectionPool, server: &mut Server, address: &Address, mut message: BytesMut) -> Result<Option<bool>, Error> {
let original = message.clone();
let code = message.get_u8() as char;
let _len = message.get_i32() as usize;
trace!("Message: {}", code);
match code {
// ReadyForQuery
'Q' => {
debug!("Sending query to server");
self.send_and_receive_loop(code, original, server, &address, &pool)
.await?;
if !server.in_transaction() {
// Report transaction executed statistics.
self.stats.transaction(self.process_id, address.id);
// Release server back to the pool if we are in transaction mode.
// If we are in session mode, we keep the server until the client disconnects.
if self.transaction_mode {
return Ok(Some(true));
}
}
}
// Terminate
'X' => {
server.checkin_cleanup().await?;
self.release();
return Ok(None);
}
// Parse
// The query with placeholders is here, e.g. `SELECT * FROM users WHERE email = $1 AND active = $2`.
'P' => {
self.buffer.put(&original[..]);
}
// Bind
// The placeholder's replacements are here, e.g. 'user@email.com' and 'true'
'B' => {
self.buffer.put(&original[..]);
}
// Describe
// Command a client can issue to describe a previously prepared named statement.
'D' => {
self.buffer.put(&original[..]);
}
// Execute
// Execute a prepared statement prepared in `P` and bound in `B`.
'E' => {
self.buffer.put(&original[..]);
}
// Sync
// Frontend (client) is asking for the query result now.
'S' => {
debug!("Sending query to server");
self.buffer.put(&original[..]);
// Clone after freeze does not allocate
let first_message_code = (*self.buffer.get(0).unwrap_or(&0)) as char;
// Almost certainly true
if first_message_code == 'P' {
// Message layout
// P followed by 32 int followed by null-terminated statement name
// So message code should be in offset 0 of the buffer, first character
// in prepared statement name would be index 5
let first_char_in_name = *self.buffer.get(5).unwrap_or(&0);
if first_char_in_name != 0 {
// This is a named prepared statement
// Server connection state will need to be cleared at checkin
server.mark_dirty();
}
}
self.send_and_receive_loop(
code,
self.buffer.clone(),
server,
&address,
&pool,
)
.await?;
self.buffer.clear();
if !server.in_transaction() {
self.stats.transaction(self.process_id, address.id);
// Release server back to the pool if we are in transaction mode.
// If we are in session mode, we keep the server until the client disconnects.
if self.transaction_mode {
return Ok(Some(true));
}
}
}
// CopyData
'd' => {
// Forward the data to the server,
// don't buffer it since it can be rather large.
self.send_server_message(server, original, &address, &pool)
.await?;
}
// CopyDone or CopyFail
// Copy is done, successfully or not.
'c' | 'f' => {
self.send_server_message(server, original, &address, &pool)
.await?;
let response = self.receive_server_message(server, &address, &pool).await?;
match write_all_half(&mut self.write, response).await {
Ok(_) => (),
Err(err) => {
server.mark_bad();
return Err(err);
}
};
if !server.in_transaction() {
self.stats.transaction(self.process_id, address.id);
// Release server back to the pool if we are in transaction mode.
// If we are in session mode, we keep the server until the client disconnects.
if self.transaction_mode {
return Ok(Some(true));
}
}
}
// Some unexpected message. We either did not implement the protocol correctly
// or this is not a Postgres client we're talking to.
_ => {
error!("Unexpected code: {}", code);
}
};
Ok(Some(false))
}
/// Release the server from the client: it can't cancel its queries anymore.
pub fn release(&self) {
let mut guard = self.client_server_map.lock();

View File

@@ -185,6 +185,7 @@ pub struct Pool {
pub query_parser_enabled: bool,
pub primary_reads_enabled: bool,
pub sharding_function: String,
pub sharding_key: Option<String>,
pub shards: HashMap<String, Shard>,
pub users: HashMap<String, User>,
}
@@ -198,6 +199,7 @@ impl Default for Pool {
query_parser_enabled: false,
primary_reads_enabled: true,
sharding_function: "pg_bigint_hash".to_string(),
sharding_key: None,
}
}
}

View File

@@ -8,6 +8,7 @@ use once_cell::sync::Lazy;
use parking_lot::{Mutex, RwLock};
use rand::seq::SliceRandom;
use rand::thread_rng;
use regex::Regex;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
@@ -68,6 +69,9 @@ pub struct PoolSettings {
// Sharding function.
pub sharding_function: ShardingFunction,
// Automatically detect sharding key in query.
pub sharding_key_regex: Option<Regex>,
}
impl Default for PoolSettings {
@@ -80,6 +84,7 @@ impl Default for PoolSettings {
query_parser_enabled: false,
primary_reads_enabled: true,
sharding_function: ShardingFunction::PgBigintHash,
sharding_key_regex: None,
}
}
}
@@ -229,6 +234,20 @@ impl ConnectionPool {
"sha1" => ShardingFunction::Sha1,
_ => unreachable!(),
},
sharding_key_regex: match &pool_config.sharding_key {
Some(sharding_key) => match Regex::new(&format!(
r"(?i) *{} *= *'?([0-9]+)'?",
sharding_key
)) {
Ok(regex) => Some(regex),
Err(err) => {
error!("Sharding key regex error: {:?}", err);
return Err(Error::BadConfig);
}
},
None => None,
},
},
};

View File

@@ -13,7 +13,7 @@ use crate::pool::PoolSettings;
use crate::sharding::Sharder;
/// Regexes used to parse custom commands.
const CUSTOM_SQL_REGEXES: [&str; 8] = [
const CUSTOM_SQL_REGEXES: [&str; 7] = [
r"(?i)^ *SET SHARDING KEY TO '?([0-9]+)'? *;? *$",
r"(?i)^ *SET SHARD TO '?([0-9]+|ANY)'? *;? *$",
r"(?i)^ *SHOW SHARD *;? *$",
@@ -21,7 +21,6 @@ const CUSTOM_SQL_REGEXES: [&str; 8] = [
r"(?i)^ *SHOW SERVER ROLE *;? *$",
r"(?i)^ *SET PRIMARY READS TO '?(on|off|default)'? *;? *$",
r"(?i)^ *SHOW PRIMARY READS *;? *$",
r"(?i)^ *SHARDED_COPY '?([0-9]+)'? *;? *$",
];
/// Custom commands.
@@ -34,7 +33,6 @@ pub enum Command {
ShowServerRole,
SetPrimaryReads,
ShowPrimaryReads,
StartShardedCopy,
}
/// Quickly test for match when a query is received.
@@ -57,11 +55,10 @@ pub struct QueryRouter {
/// Include the primary into the replica pool for reads.
primary_reads_enabled: bool,
set_manually: bool,
/// Pool configuration.
pool_settings: PoolSettings,
// Sharding key column
sharding_key_column: Option<usize>,
}
impl QueryRouter {
@@ -102,14 +99,19 @@ impl QueryRouter {
active_role: None,
query_parser_enabled: false,
primary_reads_enabled: false,
set_manually: false,
pool_settings: PoolSettings::default(),
sharding_key_column: None,
}
}
/// Pool settings can change because of a config reload.
pub fn update_pool_settings(&mut self, pool_settings: PoolSettings) {
self.pool_settings = pool_settings;
if !self.set_manually {
self.query_parser_enabled = self.pool_settings.query_parser_enabled;
self.primary_reads_enabled = self.pool_settings.primary_reads_enabled;
}
}
/// Try to parse a command and execute it.
@@ -151,7 +153,6 @@ impl QueryRouter {
4 => Command::ShowServerRole,
5 => Command::SetPrimaryReads,
6 => Command::ShowPrimaryReads,
7 => Command::StartShardedCopy,
_ => unreachable!(),
};
@@ -159,8 +160,7 @@ impl QueryRouter {
Command::SetShardingKey
| Command::SetShard
| Command::SetServerRole
| Command::SetPrimaryReads
| Command::StartShardedCopy => {
| Command::SetPrimaryReads => {
// Capture value. I know this re-runs the regex engine, but I haven't
// figured out a better way just yet. I think I can write a single Regex
// that matches all 5 custom SQL patterns, but maybe that's not very legible?
@@ -212,14 +212,9 @@ impl QueryRouter {
};
}
Command::StartShardedCopy => {
self.sharding_key_column = match value.parse::<usize>() {
Ok(value) => Some(value),
Err(_) => return None,
}
}
Command::SetServerRole => {
self.set_manually = true;
self.active_role = match value.to_ascii_lowercase().as_ref() {
"primary" => {
self.query_parser_enabled = false;
@@ -243,7 +238,7 @@ impl QueryRouter {
"default" => {
self.active_role = self.pool_settings.default_role;
self.query_parser_enabled = self.query_parser_enabled;
self.query_parser_enabled = self.pool_settings.query_parser_enabled;
self.active_role
}
@@ -252,6 +247,8 @@ impl QueryRouter {
}
Command::SetPrimaryReads => {
self.set_manually = true;
if value == "on" {
debug!("Setting primary reads to on");
self.primary_reads_enabled = true;
@@ -271,7 +268,7 @@ impl QueryRouter {
}
/// Try to infer which server to connect to based on the contents of the query.
pub fn infer_role(&mut self, mut buf: BytesMut) -> bool {
pub fn infer_role_and_shard(&mut self, mut buf: BytesMut) -> bool {
debug!("Inferring role");
let code = buf.get_u8() as char;
@@ -312,6 +309,31 @@ impl QueryRouter {
_ => return false,
};
// First find the shard key
match &self.pool_settings.sharding_key_regex {
Some(re) => {
match re.captures(&query) {
Some(group) => match group.get(1) {
Some(value) => {
let value = value.as_str().parse::<i64>().unwrap();
let sharder = Sharder::new(
self.pool_settings.shards,
self.pool_settings.sharding_function,
);
let shard = sharder.shard(value);
self.active_shard = Some(shard);
debug!("Automatically routing to shard {}", shard);
}
None => (),
},
None => (),
};
}
None => (),
};
let ast = match Parser::parse_sql(&PostgreSqlDialect {}, &query) {
Ok(ast) => ast,
Err(err) => {
@@ -388,7 +410,7 @@ mod test {
}
#[test]
fn test_infer_role_replica() {
fn test_infer_role_and_shard_replica() {
QueryRouter::setup();
let mut qr = QueryRouter::new();
assert!(qr.try_execute_command(simple_query("SET SERVER ROLE TO 'auto'")) != None);
@@ -397,22 +419,25 @@ mod test {
assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None);
let queries = vec![
simple_query("SELECT * FROM items WHERE id = 5"),
simple_query("SELECT * FROM items WHERE id = 4"),
simple_query(
"SELECT id, name, value FROM items INNER JOIN prices ON item.id = prices.item_id",
),
simple_query("WITH t AS (SELECT * FROM items) SELECT * FROM t"),
];
for query in queries {
let shards = vec![0, 0, 0];
for (idx, query) in queries.iter().enumerate() {
// It's a recognized query
assert!(qr.infer_role(query));
assert!(qr.infer_role_and_shard(query.clone()));
assert_eq!(qr.role(), Some(Role::Replica));
assert_eq!(qr.shard(), shards[idx]);
}
}
#[test]
fn test_infer_role_primary() {
fn test_infer_role_and_shard_primary() {
QueryRouter::setup();
let mut qr = QueryRouter::new();
@@ -425,24 +450,24 @@ mod test {
for query in queries {
// It's a recognized query
assert!(qr.infer_role(query));
assert!(qr.infer_role_and_shard(query));
assert_eq!(qr.role(), Some(Role::Primary));
}
}
#[test]
fn test_infer_role_primary_reads_enabled() {
fn test_infer_role_and_shard_primary_reads_enabled() {
QueryRouter::setup();
let mut qr = QueryRouter::new();
let query = simple_query("SELECT * FROM items WHERE id = 5");
assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO on")) != None);
assert!(qr.infer_role(query));
assert!(qr.infer_role_and_shard(query));
assert_eq!(qr.role(), None);
}
#[test]
fn test_infer_role_parse_prepared() {
fn test_infer_role_and_shard_parse_prepared() {
QueryRouter::setup();
let mut qr = QueryRouter::new();
qr.try_execute_command(simple_query("SET SERVER ROLE TO 'auto'"));
@@ -457,7 +482,7 @@ mod test {
res.put(prepared_stmt);
res.put_i16(0);
assert!(qr.infer_role(res));
assert!(qr.infer_role_and_shard(res));
assert_eq!(qr.role(), Some(Role::Replica));
}
@@ -621,17 +646,17 @@ mod test {
assert_eq!(qr.role(), None);
let query = simple_query("INSERT INTO test_table VALUES (1)");
assert_eq!(qr.infer_role(query), true);
assert_eq!(qr.infer_role_and_shard(query), true);
assert_eq!(qr.role(), Some(Role::Primary));
let query = simple_query("SELECT * FROM test_table");
assert_eq!(qr.infer_role(query), true);
assert_eq!(qr.infer_role_and_shard(query), true);
assert_eq!(qr.role(), Some(Role::Replica));
assert!(qr.query_parser_enabled());
let query = simple_query("SET SERVER ROLE TO 'default'");
assert!(qr.try_execute_command(query) != None);
assert!(qr.query_parser_enabled());
assert!(!qr.query_parser_enabled());
}
#[test]
@@ -644,7 +669,8 @@ mod test {
user: crate::config::User::default(),
default_role: Some(Role::Replica),
query_parser_enabled: true,
primary_reads_enabled: false,
primary_reads_enabled: true,
sharding_key_regex: None,
sharding_function: ShardingFunction::PgBigintHash,
};
let mut qr = QueryRouter::new();
@@ -658,8 +684,8 @@ mod test {
assert_eq!(qr.active_role, None);
assert_eq!(qr.active_shard, None);
assert_eq!(qr.query_parser_enabled, false);
assert_eq!(qr.primary_reads_enabled, false);
assert_eq!(qr.query_parser_enabled, true);
assert_eq!(qr.primary_reads_enabled, true);
let q1 = simple_query("SET SERVER ROLE TO 'primary'");
assert!(qr.try_execute_command(q1) != None);