Change sharding config to enum and move validation of configs into public functions (#178)

Moves config validation to own functions to enable tools to use them
Moves sharding config to enum
Makes defaults public
Make connect_timeout on pool an option which falls back to the general connect_timeout when unset
This commit is contained in:
zainkabani
2022-09-28 09:50:14 -04:00
committed by GitHub
parent af064ef447
commit 24f5eec3ea
3 changed files with 169 additions and 138 deletions

View File

@@ -13,6 +13,7 @@ use toml;
use crate::errors::Error;
use crate::pool::{ClientServerMap, ConnectionPool};
use crate::sharding::ShardingFunction;
use crate::tls::{load_certs, load_keys};
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
@@ -179,31 +180,31 @@ pub struct General {
}
impl General {
fn default_host() -> String {
pub fn default_host() -> String {
"0.0.0.0".into()
}
fn default_port() -> i16 {
pub fn default_port() -> i16 {
5432
}
fn default_connect_timeout() -> u64 {
pub fn default_connect_timeout() -> u64 {
1000
}
fn default_shutdown_timeout() -> u64 {
pub fn default_shutdown_timeout() -> u64 {
60000
}
fn default_healthcheck_timeout() -> u64 {
pub fn default_healthcheck_timeout() -> u64 {
1000
}
fn default_healthcheck_delay() -> u64 {
pub fn default_healthcheck_delay() -> u64 {
30000
}
fn default_ban_time() -> i64 {
pub fn default_ban_time() -> i64 {
60
}
}
@@ -211,15 +212,15 @@ impl General {
impl Default for General {
fn default() -> General {
General {
host: General::default_host(),
port: General::default_port(),
host: Self::default_host(),
port: Self::default_port(),
enable_prometheus_exporter: Some(false),
prometheus_exporter_port: 9930,
connect_timeout: General::default_connect_timeout(),
shutdown_timeout: General::default_shutdown_timeout(),
healthcheck_timeout: General::default_healthcheck_timeout(),
healthcheck_delay: General::default_healthcheck_delay(),
ban_time: General::default_ban_time(),
shutdown_timeout: Self::default_shutdown_timeout(),
healthcheck_timeout: Self::default_healthcheck_timeout(),
healthcheck_delay: Self::default_healthcheck_delay(),
ban_time: Self::default_ban_time(),
autoreload: false,
tls_certificate: None,
tls_private_key: None,
@@ -263,31 +264,61 @@ pub struct Pool {
#[serde(default)] // False
pub primary_reads_enabled: bool,
#[serde(default = "General::default_connect_timeout")]
pub connect_timeout: u64,
pub connect_timeout: Option<u64>,
pub sharding_function: String,
pub sharding_function: ShardingFunction,
pub shards: BTreeMap<String, Shard>,
pub users: BTreeMap<String, User>,
}
impl Pool {
fn default_pool_mode() -> PoolMode {
pub fn default_pool_mode() -> PoolMode {
PoolMode::Transaction
}
/// Validate this pool's configuration: the default role must be a
/// recognized routing target, and every shard key must be numeric
/// with an internally valid shard config.
pub fn validate(&self) -> Result<(), Error> {
    // Only these three routing targets are understood by the query router.
    let role = self.default_role.as_str();
    if !matches!(role, "any" | "primary" | "replica") {
        error!(
            "Query router default_role must be 'primary', 'replica', or 'any', got: '{}'",
            role
        );
        return Err(Error::BadConfig);
    }

    // Shard keys are used as indices, so they must parse as unsigned
    // integers; each shard also validates its own server list.
    for (shard_idx, shard) in &self.shards {
        if shard_idx.parse::<usize>().is_err() {
            error!(
                "Shard '{}' is not a valid number, shards must be numbered starting at 0",
                shard_idx
            );
            return Err(Error::BadConfig);
        }

        shard.validate()?;
    }

    Ok(())
}
}
impl Default for Pool {
fn default() -> Pool {
Pool {
pool_mode: Pool::default_pool_mode(),
pool_mode: Self::default_pool_mode(),
shards: BTreeMap::from([(String::from("1"), Shard::default())]),
users: BTreeMap::default(),
default_role: String::from("any"),
query_parser_enabled: false,
primary_reads_enabled: false,
sharding_function: "pg_bigint_hash".to_string(),
connect_timeout: General::default_connect_timeout(),
sharding_function: ShardingFunction::PgBigintHash,
connect_timeout: None,
}
}
}
@@ -306,6 +337,45 @@ pub struct Shard {
pub servers: Vec<ServerConfig>,
}
impl Shard {
    /// Validate this shard's server list: it must be non-empty, contain
    /// no duplicate server configs, and define at most one primary.
    pub fn validate(&self) -> Result<(), Error> {
        // We use addresses as unique identifiers,
        // let's make sure they are unique in the config as well.
        let mut dup_check = HashSet::new();
        let mut primary_count = 0;

        if self.servers.is_empty() {
            error!("Shard {} has no servers configured", self.database);
            return Err(Error::BadConfig);
        }

        for server in &self.servers {
            dup_check.insert(server);

            // Check that we define only zero or one primary.
            if matches!(server.role, Role::Primary) {
                primary_count += 1;
            }
        }

        if primary_count > 1 {
            // Fixed message typo: "more than on" -> "more than one".
            error!(
                "Shard {} has more than one primary configured",
                self.database
            );
            return Err(Error::BadConfig);
        }

        if dup_check.len() != self.servers.len() {
            error!("Shard {} contains duplicate server configs", self.database);
            return Err(Error::BadConfig);
        }

        Ok(())
    }
}
impl Default for Shard {
fn default() -> Shard {
Shard {
@@ -326,7 +396,7 @@ pub struct Config {
// so we should always put simple fields before nested fields
// in all serializable structs to avoid ValueAfterTable errors
// These errors occur when the toml serializer is about to produce
// ambigous toml structure like the one below
// ambiguous toml structure like the one below
// [main]
// field1_under_main = 1
// field2_under_main = 2
@@ -341,7 +411,7 @@ pub struct Config {
}
impl Config {
fn default_path() -> String {
pub fn default_path() -> String {
String::from("pgcat.toml")
}
}
@@ -349,7 +419,7 @@ impl Config {
impl Default for Config {
fn default() -> Config {
Config {
path: Config::default_path(),
path: Self::default_path(),
general: General::default(),
pools: HashMap::default(),
}
@@ -381,7 +451,7 @@ impl From<&Config> for std::collections::HashMap<String, String> {
),
(
format!("pools.{}.sharding_function", pool_name),
pool.sharding_function.clone(),
pool.sharding_function.to_string(),
),
(
format!("pools.{:?}.shard_count", pool_name),
@@ -477,9 +547,18 @@ impl Config {
"[pool: {}] Pool mode: {:?}",
pool_name, pool_config.pool_mode
);
let connect_timeout = match pool_config.connect_timeout {
Some(connect_timeout) => connect_timeout,
None => self.general.connect_timeout,
};
info!(
"[pool: {}] Connection timeout: {}ms",
pool_name, connect_timeout
);
info!(
"[pool: {}] Sharding function: {}",
pool_name, pool_config.sharding_function
pool_name,
pool_config.sharding_function.to_string()
);
info!(
"[pool: {}] Primary reads: {}",
@@ -512,6 +591,50 @@ impl Config {
}
}
}
pub fn validate(&mut self) -> Result<(), Error> {
// Validate TLS!
match self.general.tls_certificate.clone() {
Some(tls_certificate) => {
match load_certs(&Path::new(&tls_certificate)) {
Ok(_) => {
// Cert is okay, but what about the private key?
match self.general.tls_private_key.clone() {
Some(tls_private_key) => {
match load_keys(&Path::new(&tls_private_key)) {
Ok(_) => (),
Err(err) => {
error!(
"tls_private_key is incorrectly configured: {:?}",
err
);
return Err(Error::BadConfig);
}
}
}
None => {
error!("tls_certificate is set, but the tls_private_key is not");
return Err(Error::BadConfig);
}
};
}
Err(err) => {
error!("tls_certificate is incorrectly configured: {:?}", err);
return Err(Error::BadConfig);
}
}
}
None => (),
};
for (_, pool) in &mut self.pools {
pool.validate()?;
}
Ok(())
}
}
/// Get a read-only instance of the configuration
@@ -548,110 +671,7 @@ pub async fn parse(path: &str) -> Result<(), Error> {
}
};
// Validate TLS!
match config.general.tls_certificate.clone() {
Some(tls_certificate) => {
match load_certs(&Path::new(&tls_certificate)) {
Ok(_) => {
// Cert is okay, but what about the private key?
match config.general.tls_private_key.clone() {
Some(tls_private_key) => match load_keys(&Path::new(&tls_private_key)) {
Ok(_) => (),
Err(err) => {
error!("tls_private_key is incorrectly configured: {:?}", err);
return Err(Error::BadConfig);
}
},
None => {
error!("tls_certificate is set, but the tls_private_key is not");
return Err(Error::BadConfig);
}
};
}
Err(err) => {
error!("tls_certificate is incorrectly configured: {:?}", err);
return Err(Error::BadConfig);
}
}
}
None => (),
};
for (pool_name, mut pool) in &mut config.pools {
// Copy the connect timeout over for hashing.
pool.connect_timeout = config.general.connect_timeout;
match pool.sharding_function.as_ref() {
"pg_bigint_hash" => (),
"sha1" => (),
_ => {
error!(
"Supported sharding functions are: 'pg_bigint_hash', 'sha1', got: '{}' in pool {} settings",
pool.sharding_function,
pool_name
);
return Err(Error::BadConfig);
}
};
match pool.default_role.as_ref() {
"any" => (),
"primary" => (),
"replica" => (),
other => {
error!(
"Query router default_role must be 'primary', 'replica', or 'any', got: '{}'",
other
);
return Err(Error::BadConfig);
}
};
for shard in &pool.shards {
// We use addresses as unique identifiers,
// let's make sure they are unique in the config as well.
let mut dup_check = HashSet::new();
let mut primary_count = 0;
match shard.0.parse::<usize>() {
Ok(_) => (),
Err(_) => {
error!(
"Shard '{}' is not a valid number, shards must be numbered starting at 0",
shard.0
);
return Err(Error::BadConfig);
}
};
if shard.1.servers.len() == 0 {
error!("Shard {} has no servers configured", shard.0);
return Err(Error::BadConfig);
}
for server in &shard.1.servers {
dup_check.insert(server);
// Check that we define only zero or one primary.
match server.role {
Role::Primary => primary_count += 1,
_ => (),
};
}
if primary_count > 1 {
error!("Shard {} has more than on primary configured", &shard.0);
return Err(Error::BadConfig);
}
if dup_check.len() != shard.1.servers.len() {
error!("Shard {} contains duplicate server configs", &shard.0);
return Err(Error::BadConfig);
}
}
}
config.validate()?;
config.path = path.to_string();

View File

@@ -181,11 +181,14 @@ impl ConnectionPool {
get_reporter(),
);
let connect_timeout = match pool_config.connect_timeout {
Some(connect_timeout) => connect_timeout,
None => config.general.connect_timeout,
};
let pool = Pool::builder()
.max_size(user.pool_size)
.connection_timeout(std::time::Duration::from_millis(
pool_config.connect_timeout,
))
.connection_timeout(std::time::Duration::from_millis(connect_timeout))
.test_on_check_out(false)
.build(manager)
.await
@@ -221,11 +224,7 @@ impl ConnectionPool {
},
query_parser_enabled: pool_config.query_parser_enabled.clone(),
primary_reads_enabled: pool_config.primary_reads_enabled,
sharding_function: match pool_config.sharding_function.as_str() {
"pg_bigint_hash" => ShardingFunction::PgBigintHash,
"sha1" => ShardingFunction::Sha1,
_ => unreachable!(),
},
sharding_function: pool_config.sharding_function,
},
};

View File

@@ -1,3 +1,4 @@
use serde_derive::{Deserialize, Serialize};
/// Implements various sharding functions.
use sha1::{Digest, Sha1};
@@ -5,12 +6,23 @@ use sha1::{Digest, Sha1};
const PARTITION_HASH_SEED: u64 = 0x7A5B22367996DCFD;
/// The sharding functions we support.
#[derive(Debug, PartialEq, Copy, Clone)]
#[derive(Debug, PartialEq, Copy, Clone, Serialize, Deserialize, Hash, std::cmp::Eq)]
pub enum ShardingFunction {
#[serde(alias = "pg_bigint_hash", alias = "PgBigintHash")]
PgBigintHash,
#[serde(alias = "sha1", alias = "Sha1")]
Sha1,
}
/// Render the sharding function in the lowercase form used in config
/// files ("pg_bigint_hash" / "sha1").
///
/// Implemented as `Display` rather than a direct `ToString` impl
/// (clippy: `implementing ToString directly`); `to_string()` keeps
/// working for all callers via the blanket `impl<T: Display> ToString`.
impl std::fmt::Display for ShardingFunction {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ShardingFunction::PgBigintHash => write!(f, "pg_bigint_hash"),
            ShardingFunction::Sha1 => write!(f, "sha1"),
        }
    }
}
/// The sharder.
pub struct Sharder {
/// Number of shards in the cluster.