Compare commits

..

1 Commits

Author SHA1 Message Date
Lev Kokotov
622891ee5b Release 1.1 2023-07-25 10:26:38 -07:00
23 changed files with 337 additions and 245 deletions

View File

@@ -800,7 +800,7 @@ async fn pause<T>(stream: &mut T, query: &str) -> Result<(), Error>
where where
T: tokio::io::AsyncWrite + std::marker::Unpin, T: tokio::io::AsyncWrite + std::marker::Unpin,
{ {
let parts: Vec<&str> = query.split(',').map(|part| part.trim()).collect(); let parts: Vec<&str> = query.split(",").map(|part| part.trim()).collect();
if parts.len() != 2 { if parts.len() != 2 {
error_response( error_response(
@@ -847,7 +847,7 @@ async fn resume<T>(stream: &mut T, query: &str) -> Result<(), Error>
where where
T: tokio::io::AsyncWrite + std::marker::Unpin, T: tokio::io::AsyncWrite + std::marker::Unpin,
{ {
let parts: Vec<&str> = query.split(',').map(|part| part.trim()).collect(); let parts: Vec<&str> = query.split(",").map(|part| part.trim()).collect();
if parts.len() != 2 { if parts.len() != 2 {
error_response( error_response(

View File

@@ -12,7 +12,7 @@ pub struct AuthPassthrough {
impl AuthPassthrough { impl AuthPassthrough {
/// Initializes an AuthPassthrough. /// Initializes an AuthPassthrough.
pub fn new<S: ToString>(query: S, user: S, password: S) -> Self { pub fn new(query: &str, user: &str, password: &str) -> Self {
AuthPassthrough { AuthPassthrough {
password: password.to_string(), password: password.to_string(),
query: query.to_string(), query: query.to_string(),

View File

@@ -123,7 +123,7 @@ pub async fn client_entrypoint(
// Client requested a TLS connection. // Client requested a TLS connection.
Ok((ClientConnectionType::Tls, _)) => { Ok((ClientConnectionType::Tls, _)) => {
// TLS settings are configured, will setup TLS now. // TLS settings are configured, will setup TLS now.
if tls_certificate.is_some() { if tls_certificate != None {
debug!("Accepting TLS request"); debug!("Accepting TLS request");
let mut yes = BytesMut::new(); let mut yes = BytesMut::new();
@@ -431,7 +431,7 @@ where
None => "pgcat", None => "pgcat",
}; };
let client_identifier = ClientIdentifier::new(application_name, username, pool_name); let client_identifier = ClientIdentifier::new(&application_name, &username, &pool_name);
let admin = ["pgcat", "pgbouncer"] let admin = ["pgcat", "pgbouncer"]
.iter() .iter()
@@ -930,13 +930,17 @@ where
} }
// Check on plugin results. // Check on plugin results.
if let Some(PluginOutput::Deny(error)) = plugin_output { match plugin_output {
Some(PluginOutput::Deny(error)) => {
self.buffer.clear(); self.buffer.clear();
error_response(&mut self.write, &error).await?; error_response(&mut self.write, &error).await?;
plugin_output = None; plugin_output = None;
continue; continue;
} }
_ => (),
};
// Get a pool instance referenced by the most up-to-date // Get a pool instance referenced by the most up-to-date
// pointer. This ensures we always read the latest config // pointer. This ensures we always read the latest config
// when starting a query. // when starting a query.
@@ -1209,7 +1213,7 @@ where
// Safe to unwrap because we know this message has a certain length and has the code // Safe to unwrap because we know this message has a certain length and has the code
// This reads the first byte without advancing the internal pointer and mutating the bytes // This reads the first byte without advancing the internal pointer and mutating the bytes
let code = *message.first().unwrap() as char; let code = *message.get(0).unwrap() as char;
trace!("Message: {}", code); trace!("Message: {}", code);
@@ -1327,11 +1331,14 @@ where
let close: Close = (&message).try_into()?; let close: Close = (&message).try_into()?;
if close.is_prepared_statement() && !close.anonymous() { if close.is_prepared_statement() && !close.anonymous() {
if let Some(parse) = self.prepared_statements.get(&close.name) { match self.prepared_statements.get(&close.name) {
Some(parse) => {
server.will_close(&parse.generated_name); server.will_close(&parse.generated_name);
} else {
// A prepared statement slipped through? Not impossible, since we don't support PREPARE yet.
} }
// A prepared statement slipped through? Not impossible, since we don't support PREPARE yet.
None => (),
};
} }
} }
@@ -1369,7 +1376,7 @@ where
self.buffer.put(&message[..]); self.buffer.put(&message[..]);
let first_message_code = (*self.buffer.first().unwrap_or(&0)) as char; let first_message_code = (*self.buffer.get(0).unwrap_or(&0)) as char;
// Almost certainly true // Almost certainly true
if first_message_code == 'P' && !prepared_statements_enabled { if first_message_code == 'P' && !prepared_statements_enabled {

View File

@@ -25,7 +25,7 @@ pub struct Args {
} }
pub fn parse() -> Args { pub fn parse() -> Args {
Args::parse() return Args::parse();
} }
#[derive(ValueEnum, Clone, Debug)] #[derive(ValueEnum, Clone, Debug)]

View File

@@ -217,7 +217,8 @@ impl Default for User {
impl User { impl User {
fn validate(&self) -> Result<(), Error> { fn validate(&self) -> Result<(), Error> {
if let Some(min_pool_size) = self.min_pool_size { match self.min_pool_size {
Some(min_pool_size) => {
if min_pool_size > self.pool_size { if min_pool_size > self.pool_size {
error!( error!(
"min_pool_size of {} cannot be larger than pool_size of {}", "min_pool_size of {} cannot be larger than pool_size of {}",
@@ -227,6 +228,9 @@ impl User {
} }
} }
None => (),
};
Ok(()) Ok(())
} }
} }
@@ -627,9 +631,9 @@ impl Pool {
Some(key) => { Some(key) => {
// No quotes in the key so we don't have to compare quoted // No quotes in the key so we don't have to compare quoted
// to unquoted idents. // to unquoted idents.
let key = key.replace('\"', ""); let key = key.replace("\"", "");
if key.split('.').count() != 2 { if key.split(".").count() != 2 {
error!( error!(
"automatic_sharding_key '{}' must be fully qualified, e.g. t.{}`", "automatic_sharding_key '{}' must be fully qualified, e.g. t.{}`",
key, key key, key
@@ -642,7 +646,7 @@ impl Pool {
None => None, None => None,
}; };
for user in self.users.values() { for (_, user) in &self.users {
user.validate()?; user.validate()?;
} }
@@ -814,8 +818,8 @@ pub struct Query {
impl Query { impl Query {
pub fn substitute(&mut self, db: &str, user: &str) { pub fn substitute(&mut self, db: &str, user: &str) {
for col in self.result.iter_mut() { for col in self.result.iter_mut() {
for c in col { for i in 0..col.len() {
*c = c.replace("${USER}", user).replace("${DATABASE}", db); col[i] = col[i].replace("${USER}", user).replace("${DATABASE}", db);
} }
} }
} }
@@ -925,8 +929,8 @@ impl From<&Config> for std::collections::HashMap<String, String> {
( (
format!("pools.{:?}.users", pool_name), format!("pools.{:?}.users", pool_name),
pool.users pool.users
.values() .iter()
.map(|user| &user.username) .map(|(_username, user)| &user.username)
.cloned() .cloned()
.collect::<Vec<String>>() .collect::<Vec<String>>()
.join(", "), .join(", "),
@@ -1011,10 +1015,14 @@ impl Config {
Some(tls_certificate) => { Some(tls_certificate) => {
info!("TLS certificate: {}", tls_certificate); info!("TLS certificate: {}", tls_certificate);
if let Some(tls_private_key) = self.general.tls_private_key.clone() { match self.general.tls_private_key.clone() {
Some(tls_private_key) => {
info!("TLS private key: {}", tls_private_key); info!("TLS private key: {}", tls_private_key);
info!("TLS support is enabled"); info!("TLS support is enabled");
} }
None => (),
}
} }
None => { None => {
@@ -1048,8 +1056,8 @@ impl Config {
pool_name, pool_name,
pool_config pool_config
.users .users
.values() .iter()
.map(|user_cfg| user_cfg.pool_size) .map(|(_, user_cfg)| user_cfg.pool_size)
.sum::<u32>() .sum::<u32>()
.to_string() .to_string()
); );
@@ -1206,7 +1214,8 @@ impl Config {
} }
// Validate TLS! // Validate TLS!
if let Some(tls_certificate) = self.general.tls_certificate.clone() { match self.general.tls_certificate.clone() {
Some(tls_certificate) => {
match load_certs(Path::new(&tls_certificate)) { match load_certs(Path::new(&tls_certificate)) {
Ok(_) => { Ok(_) => {
// Cert is okay, but what about the private key? // Cert is okay, but what about the private key?
@@ -1232,6 +1241,8 @@ impl Config {
} }
} }
} }
None => (),
};
for pool in self.pools.values_mut() { for pool in self.pools.values_mut() {
pool.validate()?; pool.validate()?;

View File

@@ -37,11 +37,11 @@ pub struct ClientIdentifier {
} }
impl ClientIdentifier { impl ClientIdentifier {
pub fn new<S: ToString>(application_name: S, username: S, pool_name: S) -> ClientIdentifier { pub fn new(application_name: &str, username: &str, pool_name: &str) -> ClientIdentifier {
ClientIdentifier { ClientIdentifier {
application_name: application_name.to_string(), application_name: application_name.into(),
username: username.to_string(), username: username.into(),
pool_name: pool_name.to_string(), pool_name: pool_name.into(),
} }
} }
} }
@@ -63,10 +63,10 @@ pub struct ServerIdentifier {
} }
impl ServerIdentifier { impl ServerIdentifier {
pub fn new<S: ToString>(username: S, database: S) -> ServerIdentifier { pub fn new(username: &str, database: &str) -> ServerIdentifier {
ServerIdentifier { ServerIdentifier {
username: username.to_string(), username: username.into(),
database: database.to_string(), database: database.into(),
} }
} }
} }
@@ -84,36 +84,41 @@ impl std::fmt::Display for ServerIdentifier {
impl std::fmt::Display for Error { impl std::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match &self { match &self {
Error::ClientSocketError(error, client_identifier) => { &Error::ClientSocketError(error, client_identifier) => write!(
write!(f, "Error reading {error} from client {client_identifier}",) f,
"Error reading {} from client {}",
error, client_identifier
),
&Error::ClientGeneralError(error, client_identifier) => {
write!(f, "{} {}", error, client_identifier)
} }
Error::ClientGeneralError(error, client_identifier) => { &Error::ClientAuthImpossible(username) => write!(
write!(f, "{error} {client_identifier}")
}
Error::ClientAuthImpossible(username) => write!(
f, f,
"Client auth not possible, \ "Client auth not possible, \
no cleartext password set for username: {username} \ no cleartext password set for username: {} \
in config and auth passthrough (query_auth) \ in config and auth passthrough (query_auth) \
is not set up." is not set up.",
username
), ),
Error::ClientAuthPassthroughError(error, client_identifier) => write!( &Error::ClientAuthPassthroughError(error, client_identifier) => write!(
f, f,
"No cleartext password set, \ "No cleartext password set, \
and no auth passthrough could not \ and no auth passthrough could not \
obtain the hash from server for {client_identifier}, \ obtain the hash from server for {}, \
the error was: {error}", the error was: {}",
client_identifier, error
), ),
Error::ServerStartupError(error, server_identifier) => write!( &Error::ServerStartupError(error, server_identifier) => write!(
f, f,
"Error reading {error} on server startup {server_identifier}", "Error reading {} on server startup {}",
error, server_identifier,
), ),
Error::ServerAuthError(error, server_identifier) => { &Error::ServerAuthError(error, server_identifier) => {
write!(f, "{error} for {server_identifier}") write!(f, "{} for {}", error, server_identifier,)
} }
// The rest can use Debug. // The rest can use Debug.
err => write!(f, "{err:?}"), err => write!(f, "{:?}", err),
} }
} }
} }

View File

@@ -25,11 +25,18 @@ pub mod tls;
/// ///
/// * `duration` - A duration of time /// * `duration` - A duration of time
pub fn format_duration(duration: &chrono::Duration) -> String { pub fn format_duration(duration: &chrono::Duration) -> String {
let milliseconds = duration.num_milliseconds() % 1000; let milliseconds = format!("{:0>3}", duration.num_milliseconds() % 1000);
let seconds = duration.num_seconds() % 60;
let minutes = duration.num_minutes() % 60;
let hours = duration.num_hours() % 24;
let days = duration.num_days();
format!("{days}d {hours:0>2}:{minutes:0>2}:{seconds:0>2}.{milliseconds:0>3}") let seconds = format!("{:0>2}", duration.num_seconds() % 60);
let minutes = format!("{:0>2}", duration.num_minutes() % 60);
let hours = format!("{:0>2}", duration.num_hours() % 24);
let days = duration.num_days().to_string();
format!(
"{}d {}:{}:{}.{}",
days, hours, minutes, seconds, milliseconds
)
} }

View File

@@ -160,7 +160,10 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
} }
}; };
Collector::collect(); tokio::task::spawn(async move {
let mut stats_collector = Collector::default();
stats_collector.collect().await;
});
info!("Config autoreloader: {}", match config.general.autoreload { info!("Config autoreloader: {}", match config.general.autoreload {
Some(interval) => format!("{} ms", interval), Some(interval) => format!("{} ms", interval),

View File

@@ -156,10 +156,12 @@ where
match stream.write_all(&startup).await { match stream.write_all(&startup).await {
Ok(_) => Ok(()), Ok(_) => Ok(()),
Err(err) => Err(Error::SocketError(format!( Err(err) => {
return Err(Error::SocketError(format!(
"Error writing startup to server socket - Error: {:?}", "Error writing startup to server socket - Error: {:?}",
err err
))), )))
}
} }
} }
@@ -235,8 +237,8 @@ pub fn md5_hash_password(user: &str, password: &str, salt: &[u8]) -> Vec<u8> {
let mut md5 = Md5::new(); let mut md5 = Md5::new();
// First pass // First pass
md5.update(password.as_bytes()); md5.update(&password.as_bytes());
md5.update(user.as_bytes()); md5.update(&user.as_bytes());
let output = md5.finalize_reset(); let output = md5.finalize_reset();
@@ -272,7 +274,7 @@ where
{ {
let password = md5_hash_password(user, password, salt); let password = md5_hash_password(user, password, salt);
let mut message = BytesMut::with_capacity(password.len() + 5); let mut message = BytesMut::with_capacity(password.len() as usize + 5);
message.put_u8(b'p'); message.put_u8(b'p');
message.put_i32(password.len() as i32 + 4); message.put_i32(password.len() as i32 + 4);
@@ -286,7 +288,7 @@ where
S: tokio::io::AsyncWrite + std::marker::Unpin, S: tokio::io::AsyncWrite + std::marker::Unpin,
{ {
let password = md5_hash_second_pass(hash, salt); let password = md5_hash_second_pass(hash, salt);
let mut message = BytesMut::with_capacity(password.len() + 5); let mut message = BytesMut::with_capacity(password.len() as usize + 5);
message.put_u8(b'p'); message.put_u8(b'p');
message.put_i32(password.len() as i32 + 4); message.put_i32(password.len() as i32 + 4);
@@ -507,7 +509,7 @@ pub fn data_row_nullable(row: &Vec<Option<String>>) -> BytesMut {
data_row.put_i32(column.len() as i32); data_row.put_i32(column.len() as i32);
data_row.put_slice(column); data_row.put_slice(column);
} else { } else {
data_row.put_i32(-1_i32); data_row.put_i32(-1 as i32);
} }
} }
@@ -562,10 +564,12 @@ where
{ {
match stream.write_all(&buf).await { match stream.write_all(&buf).await {
Ok(_) => Ok(()), Ok(_) => Ok(()),
Err(err) => Err(Error::SocketError(format!( Err(err) => {
return Err(Error::SocketError(format!(
"Error writing to socket - Error: {:?}", "Error writing to socket - Error: {:?}",
err err
))), )))
}
} }
} }
@@ -576,10 +580,12 @@ where
{ {
match stream.write_all(buf).await { match stream.write_all(buf).await {
Ok(_) => Ok(()), Ok(_) => Ok(()),
Err(err) => Err(Error::SocketError(format!( Err(err) => {
return Err(Error::SocketError(format!(
"Error writing to socket - Error: {:?}", "Error writing to socket - Error: {:?}",
err err
))), )))
}
} }
} }
@@ -590,15 +596,19 @@ where
match stream.write_all(buf).await { match stream.write_all(buf).await {
Ok(_) => match stream.flush().await { Ok(_) => match stream.flush().await {
Ok(_) => Ok(()), Ok(_) => Ok(()),
Err(err) => Err(Error::SocketError(format!( Err(err) => {
return Err(Error::SocketError(format!(
"Error flushing socket - Error: {:?}", "Error flushing socket - Error: {:?}",
err err
))), )))
}
}, },
Err(err) => Err(Error::SocketError(format!( Err(err) => {
return Err(Error::SocketError(format!(
"Error writing to socket - Error: {:?}", "Error writing to socket - Error: {:?}",
err err
))), )))
}
} }
} }
@@ -713,7 +723,7 @@ impl BytesMutReader for Cursor<&BytesMut> {
let mut buf = vec![]; let mut buf = vec![];
match self.read_until(b'\0', &mut buf) { match self.read_until(b'\0', &mut buf) {
Ok(_) => Ok(String::from_utf8_lossy(&buf[..buf.len() - 1]).to_string()), Ok(_) => Ok(String::from_utf8_lossy(&buf[..buf.len() - 1]).to_string()),
Err(err) => Err(Error::ParseBytesError(err.to_string())), Err(err) => return Err(Error::ParseBytesError(err.to_string())),
} }
} }
} }

View File

@@ -142,12 +142,12 @@ impl MirroringManager {
}); });
Self { Self {
byte_senders, byte_senders: byte_senders,
disconnect_senders: exit_senders, disconnect_senders: exit_senders,
} }
} }
pub fn send(&mut self, bytes: &BytesMut) { pub fn send(self: &mut Self, bytes: &BytesMut) {
// We want to avoid performing an allocation if we won't be able to send the message // We want to avoid performing an allocation if we won't be able to send the message
// There is a possibility of a race here where we check the capacity and then the channel is // There is a possibility of a race here where we check the capacity and then the channel is
// closed or the capacity is reduced to 0, but mirroring is best effort anyway // closed or the capacity is reduced to 0, but mirroring is best effort anyway
@@ -169,7 +169,7 @@ impl MirroringManager {
}); });
} }
pub fn disconnect(&mut self) { pub fn disconnect(self: &mut Self) {
self.disconnect_senders self.disconnect_senders
.iter_mut() .iter_mut()
.for_each(|sender| match sender.try_send(()) { .for_each(|sender| match sender.try_send(()) {

View File

@@ -92,7 +92,7 @@ impl<'a> Plugin for Intercept<'a> {
.map(|s| { .map(|s| {
let s = s.as_str().to_string(); let s = s.as_str().to_string();
if s.is_empty() { if s == "" {
None None
} else { } else {
Some(s) Some(s)

View File

@@ -30,7 +30,6 @@ pub enum PluginOutput {
Intercept(BytesMut), Intercept(BytesMut),
} }
#[allow(clippy::ptr_arg)]
#[async_trait] #[async_trait]
pub trait Plugin { pub trait Plugin {
// Run before the query is sent to the server. // Run before the query is sent to the server.

View File

@@ -20,7 +20,7 @@ impl<'a> Prewarmer<'a> {
self.server.address(), self.server.address(),
query query
); );
self.server.query(query).await?; self.server.query(&query).await?;
} }
Ok(()) Ok(())

View File

@@ -31,7 +31,7 @@ impl<'a> Plugin for QueryLogger<'a> {
.map(|q| q.to_string()) .map(|q| q.to_string())
.collect::<Vec<String>>() .collect::<Vec<String>>()
.join("; "); .join("; ");
info!("[pool: {}][user: {}] {}", self.db, self.user, query); info!("[pool: {}][user: {}] {}", self.user, self.db, query);
Ok(PluginOutput::Allow) Ok(PluginOutput::Allow)
} }

View File

@@ -30,22 +30,27 @@ impl<'a> Plugin for TableAccess<'a> {
return Ok(PluginOutput::Allow); return Ok(PluginOutput::Allow);
} }
let control_flow = visit_relations(ast, |relation| { let mut found = None;
let relation = relation.to_string();
let table_name = relation.split('.').last().unwrap().to_string();
if self.tables.contains(&table_name) { visit_relations(ast, |relation| {
ControlFlow::Break(table_name) let relation = relation.to_string();
let parts = relation.split(".").collect::<Vec<&str>>();
let table_name = parts.last().unwrap();
if self.tables.contains(&table_name.to_string()) {
found = Some(table_name.to_string());
ControlFlow::<()>::Break(())
} else { } else {
ControlFlow::Continue(()) ControlFlow::<()>::Continue(())
} }
}); });
if let ControlFlow::Break(found) = control_flow { if let Some(found) = found {
debug!("Blocking access to table \"{found}\""); debug!("Blocking access to table \"{}\"", found);
Ok(PluginOutput::Deny(format!( Ok(PluginOutput::Deny(format!(
"permission for table \"{found}\" denied", "permission for table \"{}\" denied",
found
))) )))
} else { } else {
Ok(PluginOutput::Allow) Ok(PluginOutput::Allow)

View File

@@ -229,7 +229,8 @@ impl ConnectionPool {
let old_pool_ref = get_pool(pool_name, &user.username); let old_pool_ref = get_pool(pool_name, &user.username);
let identifier = PoolIdentifier::new(pool_name, &user.username); let identifier = PoolIdentifier::new(pool_name, &user.username);
if let Some(pool) = old_pool_ref { match old_pool_ref {
Some(pool) => {
// If the pool hasn't changed, get existing reference and insert it into the new_pools. // If the pool hasn't changed, get existing reference and insert it into the new_pools.
// We replace all pools at the end, but if the reference is kept, the pool won't get re-created (bb8). // We replace all pools at the end, but if the reference is kept, the pool won't get re-created (bb8).
if pool.config_hash == new_pool_hash_value { if pool.config_hash == new_pool_hash_value {
@@ -241,6 +242,8 @@ impl ConnectionPool {
continue; continue;
} }
} }
None => (),
}
info!( info!(
"[pool: {}][user: {}] creating new pool", "[pool: {}][user: {}] creating new pool",
@@ -625,7 +628,7 @@ impl ConnectionPool {
let mut force_healthcheck = false; let mut force_healthcheck = false;
if self.is_banned(address) { if self.is_banned(address) {
if self.try_unban(address).await { if self.try_unban(&address).await {
force_healthcheck = true; force_healthcheck = true;
} else { } else {
debug!("Address {:?} is banned", address); debug!("Address {:?} is banned", address);
@@ -745,8 +748,8 @@ impl ConnectionPool {
// Don't leave a bad connection in the pool. // Don't leave a bad connection in the pool.
server.mark_bad(); server.mark_bad();
self.ban(address, BanReason::FailedHealthCheck, Some(client_info)); self.ban(&address, BanReason::FailedHealthCheck, Some(client_info));
false return false;
} }
/// Ban an address (i.e. replica). It no longer will serve /// Ban an address (i.e. replica). It no longer will serve
@@ -858,10 +861,10 @@ impl ConnectionPool {
let guard = self.banlist.read(); let guard = self.banlist.read();
for banlist in guard.iter() { for banlist in guard.iter() {
for (address, (reason, timestamp)) in banlist.iter() { for (address, (reason, timestamp)) in banlist.iter() {
bans.push((address.clone(), (reason.clone(), *timestamp))); bans.push((address.clone(), (reason.clone(), timestamp.clone())));
} }
} }
bans return bans;
} }
/// Get the address from the host url /// Get the address from the host url
@@ -918,7 +921,7 @@ impl ConnectionPool {
} }
let busy = provisioned - idle; let busy = provisioned - idle;
debug!("{:?} has {:?} busy connections", address, busy); debug!("{:?} has {:?} busy connections", address, busy);
busy return busy;
} }
} }

View File

@@ -67,7 +67,6 @@ static CUSTOM_SQL_REGEX_SET: OnceCell<RegexSet> = OnceCell::new();
static CUSTOM_SQL_REGEX_LIST: OnceCell<Vec<Regex>> = OnceCell::new(); static CUSTOM_SQL_REGEX_LIST: OnceCell<Vec<Regex>> = OnceCell::new();
/// The query router. /// The query router.
#[derive(Default)]
pub struct QueryRouter { pub struct QueryRouter {
/// Which shard we should be talking to right now. /// Which shard we should be talking to right now.
active_shard: Option<usize>, active_shard: Option<usize>,
@@ -92,7 +91,7 @@ impl QueryRouter {
/// One-time initialization of regexes /// One-time initialization of regexes
/// that parse our custom SQL protocol. /// that parse our custom SQL protocol.
pub fn setup() -> bool { pub fn setup() -> bool {
let set = match RegexSet::new(CUSTOM_SQL_REGEXES) { let set = match RegexSet::new(&CUSTOM_SQL_REGEXES) {
Ok(rgx) => rgx, Ok(rgx) => rgx,
Err(err) => { Err(err) => {
error!("QueryRouter::setup Could not compile regex set: {:?}", err); error!("QueryRouter::setup Could not compile regex set: {:?}", err);
@@ -117,8 +116,15 @@ impl QueryRouter {
/// Create a new instance of the query router. /// Create a new instance of the query router.
/// Each client gets its own. /// Each client gets its own.
pub fn new() -> Self { pub fn new() -> QueryRouter {
Self::default() QueryRouter {
active_shard: None,
active_role: None,
query_parser_enabled: None,
primary_reads_enabled: None,
pool_settings: PoolSettings::default(),
placeholders: Vec::new(),
}
} }
/// Pool settings can change because of a config reload. /// Pool settings can change because of a config reload.
@@ -126,7 +132,7 @@ impl QueryRouter {
self.pool_settings = pool_settings; self.pool_settings = pool_settings;
} }
pub fn pool_settings(&self) -> &PoolSettings { pub fn pool_settings<'a>(&'a self) -> &'a PoolSettings {
&self.pool_settings &self.pool_settings
} }
@@ -137,7 +143,7 @@ impl QueryRouter {
let code = message_cursor.get_u8() as char; let code = message_cursor.get_u8() as char;
// Check for any sharding regex matches in any queries // Check for any sharding regex matches in any queries
match code { match code as char {
// For Parse and Query messages peek to see if they specify a shard_id as a comment early in the statement // For Parse and Query messages peek to see if they specify a shard_id as a comment early in the statement
'P' | 'Q' => { 'P' | 'Q' => {
if self.pool_settings.shard_id_regex.is_some() if self.pool_settings.shard_id_regex.is_some()
@@ -391,10 +397,14 @@ impl QueryRouter {
// or discard shard selection. If they point to the same shard though, // or discard shard selection. If they point to the same shard though,
// we can let them through as-is. // we can let them through as-is.
// This is basically building a database now :) // This is basically building a database now :)
if let Some(shard) = self.infer_shard(query) { match self.infer_shard(query) {
Some(shard) => {
self.active_shard = Some(shard); self.active_shard = Some(shard);
debug!("Automatically using shard: {:?}", self.active_shard); debug!("Automatically using shard: {:?}", self.active_shard);
} }
None => (),
};
} }
None => (), None => (),
@@ -566,8 +576,8 @@ impl QueryRouter {
.automatic_sharding_key .automatic_sharding_key
.as_ref() .as_ref()
.unwrap() .unwrap()
.split('.') .split(".")
.map(Ident::new) .map(|ident| Ident::new(ident))
.collect::<Vec<Ident>>(); .collect::<Vec<Ident>>();
// Sharding key must be always fully qualified // Sharding key must be always fully qualified
@@ -583,7 +593,7 @@ impl QueryRouter {
Expr::Identifier(ident) => { Expr::Identifier(ident) => {
// Only if we're dealing with only one table // Only if we're dealing with only one table
// and there is no ambiguity // and there is no ambiguity
if ident.value == sharding_key[1].value { if &ident.value == &sharding_key[1].value {
// Sharding key is unique enough, don't worry about // Sharding key is unique enough, don't worry about
// table names. // table names.
if &sharding_key[0].value == "*" { if &sharding_key[0].value == "*" {
@@ -596,13 +606,13 @@ impl QueryRouter {
// SELECT * FROM t WHERE sharding_key = 5 // SELECT * FROM t WHERE sharding_key = 5
// Make sure the table name from the sharding key matches // Make sure the table name from the sharding key matches
// the table name from the query. // the table name from the query.
found = sharding_key[0].value == table[0].value; found = &sharding_key[0].value == &table[0].value;
} else if table.len() == 2 { } else if table.len() == 2 {
// Table name is fully qualified with the schema: e.g. // Table name is fully qualified with the schema: e.g.
// SELECT * FROM public.t WHERE sharding_key = 5 // SELECT * FROM public.t WHERE sharding_key = 5
// Ignore the schema (TODO: at some point, we want schema support) // Ignore the schema (TODO: at some point, we want schema support)
// and use the table name only. // and use the table name only.
found = sharding_key[0].value == table[1].value; found = &sharding_key[0].value == &table[1].value;
} else { } else {
debug!("Got table name with more than two idents, which is not possible"); debug!("Got table name with more than two idents, which is not possible");
} }
@@ -614,8 +624,8 @@ impl QueryRouter {
// The key is fully qualified in the query, // The key is fully qualified in the query,
// it will exist or Postgres will throw an error. // it will exist or Postgres will throw an error.
if idents.len() == 2 { if idents.len() == 2 {
found = sharding_key[0].value == idents[0].value found = &sharding_key[0].value == &idents[0].value
&& sharding_key[1].value == idents[1].value; && &sharding_key[1].value == &idents[1].value;
} }
// TODO: key can have schema as well, e.g. public.data.id (len == 3) // TODO: key can have schema as well, e.g. public.data.id (len == 3)
} }
@@ -647,7 +657,7 @@ impl QueryRouter {
} }
Expr::Value(Value::Placeholder(placeholder)) => { Expr::Value(Value::Placeholder(placeholder)) => {
match placeholder.replace('$', "").parse::<i16>() { match placeholder.replace("$", "").parse::<i16>() {
Ok(placeholder) => result.push(ShardingKey::Placeholder(placeholder)), Ok(placeholder) => result.push(ShardingKey::Placeholder(placeholder)),
Err(_) => { Err(_) => {
debug!( debug!(
@@ -673,9 +683,12 @@ impl QueryRouter {
match &*query.body { match &*query.body {
SetExpr::Query(query) => { SetExpr::Query(query) => {
if let Some(shard) = self.infer_shard(query) { match self.infer_shard(&*query) {
Some(shard) => {
shards.insert(shard); shards.insert(shard);
} }
None => (),
};
} }
// SELECT * FROM ... // SELECT * FROM ...
@@ -685,22 +698,38 @@ impl QueryRouter {
let mut table_names = Vec::new(); let mut table_names = Vec::new();
for table in select.from.iter() { for table in select.from.iter() {
if let TableFactor::Table { name, .. } = &table.relation { match &table.relation {
TableFactor::Table { name, .. } => {
table_names.push(name.0.clone()); table_names.push(name.0.clone());
} }
_ => (),
};
// Get table names from all the joins. // Get table names from all the joins.
for join in table.joins.iter() { for join in table.joins.iter() {
if let TableFactor::Table { name, .. } = &join.relation { match &join.relation {
TableFactor::Table { name, .. } => {
table_names.push(name.0.clone()); table_names.push(name.0.clone());
} }
_ => (),
};
// We can filter results based on join conditions, e.g. // We can filter results based on join conditions, e.g.
// SELECT * FROM t INNER JOIN B ON B.sharding_key = 5; // SELECT * FROM t INNER JOIN B ON B.sharding_key = 5;
if let JoinOperator::Inner(JoinConstraint::On(expr)) = &join.join_operator { match &join.join_operator {
JoinOperator::Inner(inner_join) => match &inner_join {
JoinConstraint::On(expr) => {
// Parse the selection criteria later. // Parse the selection criteria later.
exprs.push(expr.clone()); exprs.push(expr.clone());
} }
_ => (),
},
_ => (),
};
} }
} }
@@ -774,16 +803,16 @@ impl QueryRouter {
db: &self.pool_settings.db, db: &self.pool_settings.db,
}; };
let _ = query_logger.run(self, ast).await; let _ = query_logger.run(&self, ast).await;
} }
if let Some(ref intercept) = plugins.intercept { if let Some(ref intercept) = plugins.intercept {
let mut intercept = Intercept { let mut intercept = Intercept {
enabled: intercept.enabled, enabled: intercept.enabled,
config: intercept, config: &intercept,
}; };
let result = intercept.run(self, ast).await; let result = intercept.run(&self, ast).await;
if let Ok(PluginOutput::Intercept(output)) = result { if let Ok(PluginOutput::Intercept(output)) = result {
return Ok(PluginOutput::Intercept(output)); return Ok(PluginOutput::Intercept(output));
@@ -796,7 +825,7 @@ impl QueryRouter {
tables: &table_access.tables, tables: &table_access.tables,
}; };
let result = table_access.run(self, ast).await; let result = table_access.run(&self, ast).await;
if let Ok(PluginOutput::Deny(error)) = result { if let Ok(PluginOutput::Deny(error)) = result {
return Ok(PluginOutput::Deny(error)); return Ok(PluginOutput::Deny(error));
@@ -832,7 +861,7 @@ impl QueryRouter {
/// Should we attempt to parse queries? /// Should we attempt to parse queries?
pub fn query_parser_enabled(&self) -> bool { pub fn query_parser_enabled(&self) -> bool {
match self.query_parser_enabled { let enabled = match self.query_parser_enabled {
None => { None => {
debug!( debug!(
"Using pool settings, query_parser_enabled: {}", "Using pool settings, query_parser_enabled: {}",
@@ -848,7 +877,9 @@ impl QueryRouter {
); );
value value
} }
} };
enabled
} }
pub fn primary_reads_enabled(&self) -> bool { pub fn primary_reads_enabled(&self) -> bool {
@@ -879,14 +910,10 @@ mod test {
fn test_infer_replica() { fn test_infer_replica() {
QueryRouter::setup(); QueryRouter::setup();
let mut qr = QueryRouter::new(); let mut qr = QueryRouter::new();
assert!(qr assert!(qr.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'")) != None);
.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'"))
.is_some());
assert!(qr.query_parser_enabled()); assert!(qr.query_parser_enabled());
assert!(qr assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None);
.try_execute_command(&simple_query("SET PRIMARY READS TO off"))
.is_some());
let queries = vec![ let queries = vec![
simple_query("SELECT * FROM items WHERE id = 5"), simple_query("SELECT * FROM items WHERE id = 5"),
@@ -927,9 +954,7 @@ mod test {
QueryRouter::setup(); QueryRouter::setup();
let mut qr = QueryRouter::new(); let mut qr = QueryRouter::new();
let query = simple_query("SELECT * FROM items WHERE id = 5"); let query = simple_query("SELECT * FROM items WHERE id = 5");
assert!(qr assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO on")) != None);
.try_execute_command(&simple_query("SET PRIMARY READS TO on"))
.is_some());
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok()); assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
assert_eq!(qr.role(), None); assert_eq!(qr.role(), None);
@@ -940,9 +965,7 @@ mod test {
QueryRouter::setup(); QueryRouter::setup();
let mut qr = QueryRouter::new(); let mut qr = QueryRouter::new();
qr.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'")); qr.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'"));
assert!(qr assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None);
.try_execute_command(&simple_query("SET PRIMARY READS TO off"))
.is_some());
let prepared_stmt = BytesMut::from( let prepared_stmt = BytesMut::from(
&b"WITH t AS (SELECT * FROM items WHERE name = $1) SELECT * FROM t WHERE id = $2\0"[..], &b"WITH t AS (SELECT * FROM items WHERE name = $1) SELECT * FROM t WHERE id = $2\0"[..],
@@ -1110,11 +1133,9 @@ mod test {
QueryRouter::setup(); QueryRouter::setup();
let mut qr = QueryRouter::new(); let mut qr = QueryRouter::new();
let query = simple_query("SET SERVER ROLE TO 'auto'"); let query = simple_query("SET SERVER ROLE TO 'auto'");
assert!(qr assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None);
.try_execute_command(&simple_query("SET PRIMARY READS TO off"))
.is_some());
assert!(qr.try_execute_command(&query).is_some()); assert!(qr.try_execute_command(&query) != None);
assert!(qr.query_parser_enabled()); assert!(qr.query_parser_enabled());
assert_eq!(qr.role(), None); assert_eq!(qr.role(), None);
@@ -1128,7 +1149,7 @@ mod test {
assert!(qr.query_parser_enabled()); assert!(qr.query_parser_enabled());
let query = simple_query("SET SERVER ROLE TO 'default'"); let query = simple_query("SET SERVER ROLE TO 'default'");
assert!(qr.try_execute_command(&query).is_some()); assert!(qr.try_execute_command(&query) != None);
assert!(!qr.query_parser_enabled()); assert!(!qr.query_parser_enabled());
} }
@@ -1173,11 +1194,11 @@ mod test {
assert!(!qr.primary_reads_enabled()); assert!(!qr.primary_reads_enabled());
let q1 = simple_query("SET SERVER ROLE TO 'primary'"); let q1 = simple_query("SET SERVER ROLE TO 'primary'");
assert!(qr.try_execute_command(&q1).is_some()); assert!(qr.try_execute_command(&q1) != None);
assert_eq!(qr.active_role.unwrap(), Role::Primary); assert_eq!(qr.active_role.unwrap(), Role::Primary);
let q2 = simple_query("SET SERVER ROLE TO 'default'"); let q2 = simple_query("SET SERVER ROLE TO 'default'");
assert!(qr.try_execute_command(&q2).is_some()); assert!(qr.try_execute_command(&q2) != None);
assert_eq!(qr.active_role.unwrap(), pool_settings.default_role); assert_eq!(qr.active_role.unwrap(), pool_settings.default_role);
} }
@@ -1242,17 +1263,17 @@ mod test {
// Make sure setting it works // Make sure setting it works
let q1 = simple_query("/* shard_id: 1 */ select 1 from foo;"); let q1 = simple_query("/* shard_id: 1 */ select 1 from foo;");
assert!(qr.try_execute_command(&q1).is_none()); assert!(qr.try_execute_command(&q1) == None);
assert_eq!(qr.active_shard, Some(1)); assert_eq!(qr.active_shard, Some(1));
// And make sure changing it works // And make sure changing it works
let q2 = simple_query("/* shard_id: 0 */ select 1 from foo;"); let q2 = simple_query("/* shard_id: 0 */ select 1 from foo;");
assert!(qr.try_execute_command(&q2).is_none()); assert!(qr.try_execute_command(&q2) == None);
assert_eq!(qr.active_shard, Some(0)); assert_eq!(qr.active_shard, Some(0));
// Validate setting by shard with expected shard copied from sharding.rs tests // Validate setting by shard with expected shard copied from sharding.rs tests
let q2 = simple_query("/* sharding_key: 6 */ select 1 from foo;"); let q2 = simple_query("/* sharding_key: 6 */ select 1 from foo;");
assert!(qr.try_execute_command(&q2).is_none()); assert!(qr.try_execute_command(&q2) == None);
assert_eq!(qr.active_shard, Some(2)); assert_eq!(qr.active_shard, Some(2));
} }
@@ -1390,11 +1411,9 @@ mod test {
}; };
QueryRouter::setup(); QueryRouter::setup();
let pool_settings = PoolSettings { let mut pool_settings = PoolSettings::default();
query_parser_enabled: true, pool_settings.query_parser_enabled = true;
plugins: Some(plugins), pool_settings.plugins = Some(plugins);
..Default::default()
};
let mut qr = QueryRouter::new(); let mut qr = QueryRouter::new();
qr.update_pool_settings(pool_settings); qr.update_pool_settings(pool_settings);

View File

@@ -79,12 +79,12 @@ impl ScramSha256 {
let server_message = Message::parse(message)?; let server_message = Message::parse(message)?;
if !server_message.nonce.starts_with(&self.nonce) { if !server_message.nonce.starts_with(&self.nonce) {
return Err(Error::ProtocolSyncError("SCRAM".to_string())); return Err(Error::ProtocolSyncError(format!("SCRAM")));
} }
let salt = match general_purpose::STANDARD.decode(&server_message.salt) { let salt = match general_purpose::STANDARD.decode(&server_message.salt) {
Ok(salt) => salt, Ok(salt) => salt,
Err(_) => return Err(Error::ProtocolSyncError("SCRAM".to_string())), Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))),
}; };
let salted_password = Self::hi( let salted_password = Self::hi(
@@ -166,9 +166,9 @@ impl ScramSha256 {
pub fn finish(&mut self, message: &BytesMut) -> Result<(), Error> { pub fn finish(&mut self, message: &BytesMut) -> Result<(), Error> {
let final_message = FinalMessage::parse(message)?; let final_message = FinalMessage::parse(message)?;
let verifier = match general_purpose::STANDARD.decode(final_message.value) { let verifier = match general_purpose::STANDARD.decode(&final_message.value) {
Ok(verifier) => verifier, Ok(verifier) => verifier,
Err(_) => return Err(Error::ProtocolSyncError("SCRAM".to_string())), Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))),
}; };
let mut hmac = match Hmac::<Sha256>::new_from_slice(&self.salted_password) { let mut hmac = match Hmac::<Sha256>::new_from_slice(&self.salted_password) {
@@ -230,14 +230,14 @@ impl Message {
.collect::<Vec<String>>(); .collect::<Vec<String>>();
if parts.len() != 3 { if parts.len() != 3 {
return Err(Error::ProtocolSyncError("SCRAM".to_string())); return Err(Error::ProtocolSyncError(format!("SCRAM")));
} }
let nonce = str::replace(&parts[0], "r=", ""); let nonce = str::replace(&parts[0], "r=", "");
let salt = str::replace(&parts[1], "s=", ""); let salt = str::replace(&parts[1], "s=", "");
let iterations = match str::replace(&parts[2], "i=", "").parse::<u32>() { let iterations = match str::replace(&parts[2], "i=", "").parse::<u32>() {
Ok(iterations) => iterations, Ok(iterations) => iterations,
Err(_) => return Err(Error::ProtocolSyncError("SCRAM".to_string())), Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))),
}; };
Ok(Message { Ok(Message {
@@ -257,7 +257,7 @@ impl FinalMessage {
/// Parse the server final validation message. /// Parse the server final validation message.
pub fn parse(message: &BytesMut) -> Result<FinalMessage, Error> { pub fn parse(message: &BytesMut) -> Result<FinalMessage, Error> {
if !message.starts_with(b"v=") || message.len() < 4 { if !message.starts_with(b"v=") || message.len() < 4 {
return Err(Error::ProtocolSyncError("SCRAM".to_string())); return Err(Error::ProtocolSyncError(format!("SCRAM")));
} }
Ok(FinalMessage { Ok(FinalMessage {

View File

@@ -316,7 +316,10 @@ impl Server {
// Something else? // Something else?
m => { m => {
return Err(Error::SocketError(format!("Unknown message: {}", { m }))); return Err(Error::SocketError(format!(
"Unknown message: {}",
m as char
)));
} }
} }
} else { } else {
@@ -334,18 +337,27 @@ impl Server {
None => &user.username, None => &user.username,
}; };
let password = user.server_password.as_ref(); let password = match user.server_password {
Some(ref server_password) => Some(server_password),
None => match user.password {
Some(ref password) => Some(password),
None => None,
},
};
startup(&mut stream, username, database).await?; startup(&mut stream, username, database).await?;
let mut server_info = BytesMut::new(); let mut server_info = BytesMut::new();
let mut process_id: i32 = 0; let mut process_id: i32 = 0;
let mut secret_key: i32 = 0; let mut secret_key: i32 = 0;
let server_identifier = ServerIdentifier::new(username, database); let server_identifier = ServerIdentifier::new(username, &database);
// We'll be handling multiple packets, but they will all be structured the same. // We'll be handling multiple packets, but they will all be structured the same.
// We'll loop here until this exchange is complete. // We'll loop here until this exchange is complete.
let mut scram: Option<ScramSha256> = password.map(|password| ScramSha256::new(password)); let mut scram: Option<ScramSha256> = match password {
Some(password) => Some(ScramSha256::new(password)),
None => None,
};
loop { loop {
let code = match stream.read_u8().await { let code = match stream.read_u8().await {
@@ -741,7 +753,7 @@ impl Server {
self.mirror_send(messages); self.mirror_send(messages);
self.stats().data_sent(messages.len()); self.stats().data_sent(messages.len());
match write_all_flush(&mut self.stream, messages).await { match write_all_flush(&mut self.stream, &messages).await {
Ok(_) => { Ok(_) => {
// Successfully sent to server // Successfully sent to server
self.last_activity = SystemTime::now(); self.last_activity = SystemTime::now();
@@ -1187,14 +1199,16 @@ impl Server {
} }
pub fn mirror_send(&mut self, bytes: &BytesMut) { pub fn mirror_send(&mut self, bytes: &BytesMut) {
if let Some(manager) = self.mirror_manager.as_mut() { match self.mirror_manager.as_mut() {
manager.send(bytes); Some(manager) => manager.send(bytes),
None => (),
} }
} }
pub fn mirror_disconnect(&mut self) { pub fn mirror_disconnect(&mut self) {
if let Some(manager) = self.mirror_manager.as_mut() { match self.mirror_manager.as_mut() {
manager.disconnect(); Some(manager) => manager.disconnect(),
None => (),
} }
} }
@@ -1222,7 +1236,7 @@ impl Server {
server.send(&simple_query(query)).await?; server.send(&simple_query(query)).await?;
let mut message = server.recv().await?; let mut message = server.recv().await?;
parse_query_message(&mut message).await Ok(parse_query_message(&mut message).await?)
} }
} }

View File

@@ -64,7 +64,7 @@ impl Sharder {
fn sha1(&self, key: i64) -> usize { fn sha1(&self, key: i64) -> usize {
let mut hasher = Sha1::new(); let mut hasher = Sha1::new();
hasher.update(key.to_string().as_bytes()); hasher.update(&key.to_string().as_bytes());
let result = hasher.finalize(); let result = hasher.finalize();

View File

@@ -77,12 +77,13 @@ impl Reporter {
/// The statistics collector which used for calculating averages /// The statistics collector which used for calculating averages
/// There is only one collector (kind of like a singleton) /// There is only one collector (kind of like a singleton)
/// it updates averages every 15 seconds. /// it updates averages every 15 seconds.
pub struct Collector; #[derive(Default)]
pub struct Collector {}
impl Collector { impl Collector {
/// The statistics collection handler. It will collect statistics /// The statistics collection handler. It will collect statistics
/// for `address_id`s starting at 0 up to `addresses`. /// for `address_id`s starting at 0 up to `addresses`.
pub fn collect() { pub async fn collect(&mut self) {
info!("Events reporter started"); info!("Events reporter started");
tokio::task::spawn(async move { tokio::task::spawn(async move {

View File

@@ -86,11 +86,11 @@ impl PoolStats {
} }
} }
map return map;
} }
pub fn generate_header() -> Vec<(&'static str, DataType)> { pub fn generate_header() -> Vec<(&'static str, DataType)> {
vec![ return vec![
("database", DataType::Text), ("database", DataType::Text),
("user", DataType::Text), ("user", DataType::Text),
("pool_mode", DataType::Text), ("pool_mode", DataType::Text),
@@ -105,11 +105,11 @@ impl PoolStats {
("sv_login", DataType::Numeric), ("sv_login", DataType::Numeric),
("maxwait", DataType::Numeric), ("maxwait", DataType::Numeric),
("maxwait_us", DataType::Numeric), ("maxwait_us", DataType::Numeric),
] ];
} }
pub fn generate_row(&self) -> Vec<String> { pub fn generate_row(&self) -> Vec<String> {
vec![ return vec![
self.identifier.db.clone(), self.identifier.db.clone(),
self.identifier.user.clone(), self.identifier.user.clone(),
self.mode.to_string(), self.mode.to_string(),
@@ -124,7 +124,7 @@ impl PoolStats {
self.sv_login.to_string(), self.sv_login.to_string(),
(self.maxwait / 1_000_000).to_string(), (self.maxwait / 1_000_000).to_string(),
(self.maxwait % 1_000_000).to_string(), (self.maxwait % 1_000_000).to_string(),
] ];
} }
} }

View File

@@ -44,17 +44,25 @@ impl Tls {
pub fn new() -> Result<Self, Error> { pub fn new() -> Result<Self, Error> {
let config = get_config(); let config = get_config();
let certs = load_certs(Path::new(&config.general.tls_certificate.unwrap())) let certs = match load_certs(Path::new(&config.general.tls_certificate.unwrap())) {
.map_err(|_| Error::TlsError)?; Ok(certs) => certs,
let key_der = load_keys(Path::new(&config.general.tls_private_key.unwrap())) Err(_) => return Err(Error::TlsError),
.map_err(|_| Error::TlsError)? };
.remove(0);
let config = rustls::ServerConfig::builder() let mut keys = match load_keys(Path::new(&config.general.tls_private_key.unwrap())) {
Ok(keys) => keys,
Err(_) => return Err(Error::TlsError),
};
let config = match rustls::ServerConfig::builder()
.with_safe_defaults() .with_safe_defaults()
.with_no_client_auth() .with_no_client_auth()
.with_single_cert(certs, key_der) .with_single_cert(certs, keys.remove(0))
.map_err(|_| Error::TlsError)?; .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err))
{
Ok(c) => c,
Err(_) => return Err(Error::TlsError),
};
Ok(Tls { Ok(Tls {
acceptor: TlsAcceptor::from(Arc::new(config)), acceptor: TlsAcceptor::from(Arc::new(config)),