diff --git a/src/admin.rs b/src/admin.rs index 6c83f9b..c1c897b 100644 --- a/src/admin.rs +++ b/src/admin.rs @@ -800,7 +800,7 @@ async fn pause(stream: &mut T, query: &str) -> Result<(), Error> where T: tokio::io::AsyncWrite + std::marker::Unpin, { - let parts: Vec<&str> = query.split(",").map(|part| part.trim()).collect(); + let parts: Vec<&str> = query.split(',').map(|part| part.trim()).collect(); if parts.len() != 2 { error_response( @@ -847,7 +847,7 @@ async fn resume(stream: &mut T, query: &str) -> Result<(), Error> where T: tokio::io::AsyncWrite + std::marker::Unpin, { - let parts: Vec<&str> = query.split(",").map(|part| part.trim()).collect(); + let parts: Vec<&str> = query.split(',').map(|part| part.trim()).collect(); if parts.len() != 2 { error_response( diff --git a/src/client.rs b/src/client.rs index 7d5e979..23315e4 100644 --- a/src/client.rs +++ b/src/client.rs @@ -123,7 +123,7 @@ pub async fn client_entrypoint( // Client requested a TLS connection. Ok((ClientConnectionType::Tls, _)) => { // TLS settings are configured, will setup TLS now. - if tls_certificate != None { + if tls_certificate.is_some() { debug!("Accepting TLS request"); let mut yes = BytesMut::new(); @@ -431,7 +431,7 @@ where None => "pgcat", }; - let client_identifier = ClientIdentifier::new(&application_name, &username, &pool_name); + let client_identifier = ClientIdentifier::new(application_name, username, pool_name); let admin = ["pgcat", "pgbouncer"] .iter() @@ -930,16 +930,12 @@ where } // Check on plugin results. - match plugin_output { - Some(PluginOutput::Deny(error)) => { - self.buffer.clear(); - error_response(&mut self.write, &error).await?; - plugin_output = None; - continue; - } - - _ => (), - }; + if let Some(PluginOutput::Deny(error)) = plugin_output { + self.buffer.clear(); + error_response(&mut self.write, &error).await?; + plugin_output = None; + continue; + } // Get a pool instance referenced by the most up-to-date // pointer. This ensures we always read the latest config @@ -1213,7 +1209,7 @@ where // Safe to unwrap because we know this message has a certain length and has the code // This reads the first byte without advancing the internal pointer and mutating the bytes - let code = *message.get(0).unwrap() as char; + let code = *message.first().unwrap() as char; trace!("Message: {}", code); @@ -1331,14 +1327,11 @@ where let close: Close = (&message).try_into()?; if close.is_prepared_statement() && !close.anonymous() { - match self.prepared_statements.get(&close.name) { - Some(parse) => { - server.will_close(&parse.generated_name); - } - + if let Some(parse) = self.prepared_statements.get(&close.name) { + server.will_close(&parse.generated_name); + } else { // A prepared statement slipped through? Not impossible, since we don't support PREPARE yet. - None => (), - }; + } } } @@ -1376,7 +1369,7 @@ where self.buffer.put(&message[..]); - let first_message_code = (*self.buffer.get(0).unwrap_or(&0)) as char; + let first_message_code = (*self.buffer.first().unwrap_or(&0)) as char; // Almost certainly true if first_message_code == 'P' && !prepared_statements_enabled { diff --git a/src/cmd_args.rs b/src/cmd_args.rs index 3989d67..1abb7ed 100644 --- a/src/cmd_args.rs +++ b/src/cmd_args.rs @@ -25,7 +25,7 @@ pub struct Args { } pub fn parse() -> Args { - return Args::parse(); + Args::parse() } #[derive(ValueEnum, Clone, Debug)] diff --git a/src/config.rs b/src/config.rs index 9228b9b..b369d54 100644 --- a/src/config.rs +++ b/src/config.rs @@ -217,19 +217,15 @@ impl Default for User { impl User { fn validate(&self) -> Result<(), Error> { - match self.min_pool_size { - Some(min_pool_size) => { - if min_pool_size > self.pool_size { - error!( - "min_pool_size of {} cannot be larger than pool_size of {}", - min_pool_size, self.pool_size - ); - return Err(Error::BadConfig); - } + if let Some(min_pool_size) = self.min_pool_size { + if min_pool_size > self.pool_size { + error!( + "min_pool_size of {} cannot be larger than pool_size of {}", + min_pool_size, self.pool_size + ); + return Err(Error::BadConfig); } - - None => (), - }; + } Ok(()) } @@ -631,9 +627,9 @@ impl Pool { Some(key) => { // No quotes in the key so we don't have to compare quoted // to unquoted idents. - let key = key.replace("\"", ""); + let key = key.replace('\"', ""); - if key.split(".").count() != 2 { + if key.split('.').count() != 2 { error!( "automatic_sharding_key '{}' must be fully qualified, e.g. t.{}`", key, key @@ -646,7 +642,7 @@ impl Pool { None => None, }; - for (_, user) in &self.users { + for user in self.users.values() { user.validate()?; } @@ -818,8 +814,8 @@ pub struct Query { impl Query { pub fn substitute(&mut self, db: &str, user: &str) { for col in self.result.iter_mut() { - for i in 0..col.len() { - col[i] = col[i].replace("${USER}", user).replace("${DATABASE}", db); + for c in col { + *c = c.replace("${USER}", user).replace("${DATABASE}", db); } } } @@ -929,8 +925,8 @@ impl From<&Config> for std::collections::HashMap { ( format!("pools.{:?}.users", pool_name), pool.users - .iter() - .map(|(_username, user)| &user.username) + .values() + .map(|user| &user.username) .cloned() .collect::>() .join(", "), @@ -1015,13 +1011,9 @@ impl Config { Some(tls_certificate) => { info!("TLS certificate: {}", tls_certificate); - match self.general.tls_private_key.clone() { - Some(tls_private_key) => { - info!("TLS private key: {}", tls_private_key); - info!("TLS support is enabled"); - } - - None => (), + if let Some(tls_private_key) = self.general.tls_private_key.clone() { + info!("TLS private key: {}", tls_private_key); + info!("TLS support is enabled"); } } @@ -1056,8 +1048,8 @@ impl Config { pool_name, pool_config .users - .iter() - .map(|(_, user_cfg)| user_cfg.pool_size) + .values() + .map(|user_cfg| user_cfg.pool_size) .sum::() .to_string() ); @@ -1214,35 +1206,32 @@ impl Config { } // Validate TLS! - match self.general.tls_certificate.clone() { - Some(tls_certificate) => { - match load_certs(Path::new(&tls_certificate)) { - Ok(_) => { - // Cert is okay, but what about the private key? - match self.general.tls_private_key.clone() { - Some(tls_private_key) => match load_keys(Path::new(&tls_private_key)) { - Ok(_) => (), - Err(err) => { - error!("tls_private_key is incorrectly configured: {:?}", err); - return Err(Error::BadConfig); - } - }, - - None => { - error!("tls_certificate is set, but the tls_private_key is not"); + if let Some(tls_certificate) = self.general.tls_certificate.clone() { + match load_certs(Path::new(&tls_certificate)) { + Ok(_) => { + // Cert is okay, but what about the private key? + match self.general.tls_private_key.clone() { + Some(tls_private_key) => match load_keys(Path::new(&tls_private_key)) { + Ok(_) => (), + Err(err) => { + error!("tls_private_key is incorrectly configured: {:?}", err); return Err(Error::BadConfig); } - }; - } + }, - Err(err) => { - error!("tls_certificate is incorrectly configured: {:?}", err); - return Err(Error::BadConfig); - } + None => { + error!("tls_certificate is set, but the tls_private_key is not"); + return Err(Error::BadConfig); + } + }; + } + + Err(err) => { + error!("tls_certificate is incorrectly configured: {:?}", err); + return Err(Error::BadConfig); } } - None => (), - }; + } for pool in self.pools.values_mut() { pool.validate()?; diff --git a/src/messages.rs b/src/messages.rs index 8ebc00a..c85f055 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -157,7 +157,7 @@ where match stream.write_all(&startup).await { Ok(_) => Ok(()), Err(err) => { - return Err(Error::SocketError(format!( + Err(Error::SocketError(format!( "Error writing startup to server socket - Error: {:?}", err ))) @@ -237,8 +237,8 @@ pub fn md5_hash_password(user: &str, password: &str, salt: &[u8]) -> Vec { let mut md5 = Md5::new(); // First pass - md5.update(&password.as_bytes()); - md5.update(&user.as_bytes()); + md5.update(password.as_bytes()); + md5.update(user.as_bytes()); let output = md5.finalize_reset(); @@ -274,7 +274,7 @@ where { let password = md5_hash_password(user, password, salt); - let mut message = BytesMut::with_capacity(password.len() as usize + 5); + let mut message = BytesMut::with_capacity(password.len() + 5); message.put_u8(b'p'); message.put_i32(password.len() as i32 + 4); @@ -288,7 +288,7 @@ where S: tokio::io::AsyncWrite + std::marker::Unpin, { let password = md5_hash_second_pass(hash, salt); - let mut message = BytesMut::with_capacity(password.len() as usize + 5); + let mut message = BytesMut::with_capacity(password.len() + 5); message.put_u8(b'p'); message.put_i32(password.len() as i32 + 4); @@ -509,7 +509,7 @@ pub fn data_row_nullable(row: &Vec>) -> BytesMut { data_row.put_i32(column.len() as i32); data_row.put_slice(column); } else { - data_row.put_i32(-1 as i32); + data_row.put_i32(-1_i32); } } @@ -565,7 +565,7 @@ where match stream.write_all(&buf).await { Ok(_) => Ok(()), Err(err) => { - return Err(Error::SocketError(format!( + Err(Error::SocketError(format!( "Error writing to socket - Error: {:?}", err ))) @@ -581,7 +581,7 @@ where match stream.write_all(buf).await { Ok(_) => Ok(()), Err(err) => { - return Err(Error::SocketError(format!( + Err(Error::SocketError(format!( "Error writing to socket - Error: {:?}", err ))) @@ -597,14 +597,14 @@ where Ok(_) => match stream.flush().await { Ok(_) => Ok(()), Err(err) => { - return Err(Error::SocketError(format!( + Err(Error::SocketError(format!( "Error flushing socket - Error: {:?}", err ))) } }, Err(err) => { - return Err(Error::SocketError(format!( + Err(Error::SocketError(format!( "Error writing to socket - Error: {:?}", err ))) @@ -723,7 +723,7 @@ impl BytesMutReader for Cursor<&BytesMut> { let mut buf = vec![]; match self.read_until(b'\0', &mut buf) { Ok(_) => Ok(String::from_utf8_lossy(&buf[..buf.len() - 1]).to_string()), - Err(err) => return Err(Error::ParseBytesError(err.to_string())), + Err(err) => Err(Error::ParseBytesError(err.to_string())), } } } diff --git a/src/mirrors.rs b/src/mirrors.rs index 0f2b02c..24c6889 100644 --- a/src/mirrors.rs +++ b/src/mirrors.rs @@ -142,12 +142,12 @@ impl MirroringManager { }); Self { - byte_senders: byte_senders, + byte_senders, disconnect_senders: exit_senders, } } - pub fn send(self: &mut Self, bytes: &BytesMut) { + pub fn send(&mut self, bytes: &BytesMut) { // We want to avoid performing an allocation if we won't be able to send the message // There is a possibility of a race here where we check the capacity and then the channel is // closed or the capacity is reduced to 0, but mirroring is best effort anyway @@ -169,7 +169,7 @@ impl MirroringManager { }); } - pub fn disconnect(self: &mut Self) { + pub fn disconnect(&mut self) { self.disconnect_senders .iter_mut() .for_each(|sender| match sender.try_send(()) { diff --git a/src/plugins/intercept.rs b/src/plugins/intercept.rs index 166294b..d13ab07 100644 --- a/src/plugins/intercept.rs +++ b/src/plugins/intercept.rs @@ -92,7 +92,7 @@ impl<'a> Plugin for Intercept<'a> { .map(|s| { let s = s.as_str().to_string(); - if s == "" { + if s.is_empty() { None } else { Some(s) diff --git a/src/plugins/mod.rs b/src/plugins/mod.rs index 5ef6009..b30ef29 100644 --- a/src/plugins/mod.rs +++ b/src/plugins/mod.rs @@ -30,6 +30,7 @@ pub enum PluginOutput { Intercept(BytesMut), } +#[allow(clippy::ptr_arg)] #[async_trait] pub trait Plugin { // Run before the query is sent to the server. diff --git a/src/plugins/prewarmer.rs b/src/plugins/prewarmer.rs index a09bbe9..cd93db9 100644 --- a/src/plugins/prewarmer.rs +++ b/src/plugins/prewarmer.rs @@ -20,7 +20,7 @@ impl<'a> Prewarmer<'a> { self.server.address(), query ); - self.server.query(&query).await?; + self.server.query(query).await?; } Ok(()) diff --git a/src/plugins/table_access.rs b/src/plugins/table_access.rs index 79c1260..b8153b5 100644 --- a/src/plugins/table_access.rs +++ b/src/plugins/table_access.rs @@ -34,7 +34,7 @@ impl<'a> Plugin for TableAccess<'a> { visit_relations(ast, |relation| { let relation = relation.to_string(); - let parts = relation.split(".").collect::>(); + let parts = relation.split('.').collect::>(); let table_name = parts.last().unwrap(); if self.tables.contains(&table_name.to_string()) { diff --git a/src/pool.rs b/src/pool.rs index b929352..c45643a 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -229,20 +229,17 @@ impl ConnectionPool { let old_pool_ref = get_pool(pool_name, &user.username); let identifier = PoolIdentifier::new(pool_name, &user.username); - match old_pool_ref { - Some(pool) => { - // If the pool hasn't changed, get existing reference and insert it into the new_pools. - // We replace all pools at the end, but if the reference is kept, the pool won't get re-created (bb8). - if pool.config_hash == new_pool_hash_value { - info!( - "[pool: {}][user: {}] has not changed", - pool_name, user.username - ); - new_pools.insert(identifier.clone(), pool.clone()); - continue; - } + if let Some(pool) = old_pool_ref { + // If the pool hasn't changed, get existing reference and insert it into the new_pools. + // We replace all pools at the end, but if the reference is kept, the pool won't get re-created (bb8). + if pool.config_hash == new_pool_hash_value { + info!( + "[pool: {}][user: {}] has not changed", + pool_name, user.username + ); + new_pools.insert(identifier.clone(), pool.clone()); + continue; } - None => (), } info!( @@ -628,7 +625,7 @@ impl ConnectionPool { let mut force_healthcheck = false; if self.is_banned(address) { - if self.try_unban(&address).await { + if self.try_unban(address).await { force_healthcheck = true; } else { debug!("Address {:?} is banned", address); @@ -748,8 +745,8 @@ impl ConnectionPool { // Don't leave a bad connection in the pool. server.mark_bad(); - self.ban(&address, BanReason::FailedHealthCheck, Some(client_info)); - return false; + self.ban(address, BanReason::FailedHealthCheck, Some(client_info)); + false } /// Ban an address (i.e. replica). It no longer will serve @@ -861,10 +858,10 @@ impl ConnectionPool { let guard = self.banlist.read(); for banlist in guard.iter() { for (address, (reason, timestamp)) in banlist.iter() { - bans.push((address.clone(), (reason.clone(), timestamp.clone()))); + bans.push((address.clone(), (reason.clone(), *timestamp))); } } - return bans; + bans } /// Get the address from the host url @@ -921,7 +918,7 @@ impl ConnectionPool { } let busy = provisioned - idle; debug!("{:?} has {:?} busy connections", address, busy); - return busy; + busy } } diff --git a/src/query_router.rs b/src/query_router.rs index 126b813..58e6315 100644 --- a/src/query_router.rs +++ b/src/query_router.rs @@ -67,6 +67,7 @@ static CUSTOM_SQL_REGEX_SET: OnceCell = OnceCell::new(); static CUSTOM_SQL_REGEX_LIST: OnceCell> = OnceCell::new(); /// The query router. +#[derive(Default)] pub struct QueryRouter { /// Which shard we should be talking to right now. active_shard: Option, @@ -91,7 +92,7 @@ impl QueryRouter { /// One-time initialization of regexes /// that parse our custom SQL protocol. pub fn setup() -> bool { - let set = match RegexSet::new(&CUSTOM_SQL_REGEXES) { + let set = match RegexSet::new(CUSTOM_SQL_REGEXES) { Ok(rgx) => rgx, Err(err) => { error!("QueryRouter::setup Could not compile regex set: {:?}", err); @@ -116,15 +117,8 @@ impl QueryRouter { /// Create a new instance of the query router. /// Each client gets its own. - pub fn new() -> QueryRouter { - QueryRouter { - active_shard: None, - active_role: None, - query_parser_enabled: None, - primary_reads_enabled: None, - pool_settings: PoolSettings::default(), - placeholders: Vec::new(), - } + pub fn new() -> Self { + Self::default() } /// Pool settings can change because of a config reload. @@ -132,7 +126,7 @@ impl QueryRouter { self.pool_settings = pool_settings; } - pub fn pool_settings<'a>(&'a self) -> &'a PoolSettings { + pub fn pool_settings(&self) -> &PoolSettings { &self.pool_settings } @@ -143,7 +137,7 @@ impl QueryRouter { let code = message_cursor.get_u8() as char; // Check for any sharding regex matches in any queries - match code as char { + match code { // For Parse and Query messages peek to see if they specify a shard_id as a comment early in the statement 'P' | 'Q' => { if self.pool_settings.shard_id_regex.is_some() @@ -397,14 +391,10 @@ impl QueryRouter { // or discard shard selection. If they point to the same shard though, // we can let them through as-is. // This is basically building a database now :) - match self.infer_shard(query) { - Some(shard) => { - self.active_shard = Some(shard); - debug!("Automatically using shard: {:?}", self.active_shard); - } - - None => (), - }; + if let Some(shard) = self.infer_shard(query) { + self.active_shard = Some(shard); + debug!("Automatically using shard: {:?}", self.active_shard); + } } None => (), @@ -576,8 +566,8 @@ impl QueryRouter { .automatic_sharding_key .as_ref() .unwrap() - .split(".") - .map(|ident| Ident::new(ident)) + .split('.') + .map(Ident::new) .collect::>(); // Sharding key must be always fully qualified @@ -593,7 +583,7 @@ impl QueryRouter { Expr::Identifier(ident) => { // Only if we're dealing with only one table // and there is no ambiguity - if &ident.value == &sharding_key[1].value { + if ident.value == sharding_key[1].value { // Sharding key is unique enough, don't worry about // table names. if &sharding_key[0].value == "*" { @@ -606,13 +596,13 @@ impl QueryRouter { // SELECT * FROM t WHERE sharding_key = 5 // Make sure the table name from the sharding key matches // the table name from the query. - found = &sharding_key[0].value == &table[0].value; + found = sharding_key[0].value == table[0].value; } else if table.len() == 2 { // Table name is fully qualified with the schema: e.g. // SELECT * FROM public.t WHERE sharding_key = 5 // Ignore the schema (TODO: at some point, we want schema support) // and use the table name only. - found = &sharding_key[0].value == &table[1].value; + found = sharding_key[0].value == table[1].value; } else { debug!("Got table name with more than two idents, which is not possible"); } @@ -624,8 +614,8 @@ impl QueryRouter { // The key is fully qualified in the query, // it will exist or Postgres will throw an error. if idents.len() == 2 { - found = &sharding_key[0].value == &idents[0].value - && &sharding_key[1].value == &idents[1].value; + found = sharding_key[0].value == idents[0].value + && sharding_key[1].value == idents[1].value; } // TODO: key can have schema as well, e.g. public.data.id (len == 3) } @@ -657,7 +647,7 @@ impl QueryRouter { } Expr::Value(Value::Placeholder(placeholder)) => { - match placeholder.replace("$", "").parse::() { + match placeholder.replace('$', "").parse::() { Ok(placeholder) => result.push(ShardingKey::Placeholder(placeholder)), Err(_) => { debug!( @@ -683,12 +673,9 @@ impl QueryRouter { match &*query.body { SetExpr::Query(query) => { - match self.infer_shard(&*query) { - Some(shard) => { - shards.insert(shard); - } - None => (), - }; + if let Some(shard) = self.infer_shard(query) { + shards.insert(shard); + } } // SELECT * FROM ... @@ -698,38 +685,22 @@ impl QueryRouter { let mut table_names = Vec::new(); for table in select.from.iter() { - match &table.relation { - TableFactor::Table { name, .. } => { - table_names.push(name.0.clone()); - } - - _ => (), - }; + if let TableFactor::Table { name, .. } = &table.relation { + table_names.push(name.0.clone()); + } // Get table names from all the joins. for join in table.joins.iter() { - match &join.relation { - TableFactor::Table { name, .. } => { - table_names.push(name.0.clone()); - } - - _ => (), - }; + if let TableFactor::Table { name, .. } = &join.relation { + table_names.push(name.0.clone()); + } // We can filter results based on join conditions, e.g. // SELECT * FROM t INNER JOIN B ON B.sharding_key = 5; - match &join.join_operator { - JoinOperator::Inner(inner_join) => match &inner_join { - JoinConstraint::On(expr) => { - // Parse the selection criteria later. - exprs.push(expr.clone()); - } - - _ => (), - }, - - _ => (), - }; + if let JoinOperator::Inner(JoinConstraint::On(expr)) = &join.join_operator { + // Parse the selection criteria later. + exprs.push(expr.clone()); + } } } @@ -803,16 +774,16 @@ impl QueryRouter { db: &self.pool_settings.db, }; - let _ = query_logger.run(&self, ast).await; + let _ = query_logger.run(self, ast).await; } if let Some(ref intercept) = plugins.intercept { let mut intercept = Intercept { enabled: intercept.enabled, - config: &intercept, + config: intercept, }; - let result = intercept.run(&self, ast).await; + let result = intercept.run(self, ast).await; if let Ok(PluginOutput::Intercept(output)) = result { return Ok(PluginOutput::Intercept(output)); @@ -825,7 +796,7 @@ impl QueryRouter { tables: &table_access.tables, }; - let result = table_access.run(&self, ast).await; + let result = table_access.run(self, ast).await; if let Ok(PluginOutput::Deny(error)) = result { return Ok(PluginOutput::Deny(error)); @@ -861,7 +832,7 @@ impl QueryRouter { /// Should we attempt to parse queries? pub fn query_parser_enabled(&self) -> bool { - let enabled = match self.query_parser_enabled { + match self.query_parser_enabled { None => { debug!( "Using pool settings, query_parser_enabled: {}", @@ -877,9 +848,7 @@ impl QueryRouter { ); value } - }; - - enabled + } } pub fn primary_reads_enabled(&self) -> bool { @@ -910,10 +879,14 @@ mod test { fn test_infer_replica() { QueryRouter::setup(); let mut qr = QueryRouter::new(); - assert!(qr.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'")) != None); + assert!(qr + .try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'")) + .is_some()); assert!(qr.query_parser_enabled()); - assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None); + assert!(qr + .try_execute_command(&simple_query("SET PRIMARY READS TO off")) + .is_some()); let queries = vec![ simple_query("SELECT * FROM items WHERE id = 5"), @@ -954,7 +927,9 @@ mod test { QueryRouter::setup(); let mut qr = QueryRouter::new(); let query = simple_query("SELECT * FROM items WHERE id = 5"); - assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO on")) != None); + assert!(qr + .try_execute_command(&simple_query("SET PRIMARY READS TO on")) + .is_some()); assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok()); assert_eq!(qr.role(), None); @@ -965,7 +940,9 @@ mod test { QueryRouter::setup(); let mut qr = QueryRouter::new(); qr.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'")); - assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None); + assert!(qr + .try_execute_command(&simple_query("SET PRIMARY READS TO off")) + .is_some()); let prepared_stmt = BytesMut::from( &b"WITH t AS (SELECT * FROM items WHERE name = $1) SELECT * FROM t WHERE id = $2\0"[..], @@ -1133,9 +1110,11 @@ mod test { QueryRouter::setup(); let mut qr = QueryRouter::new(); let query = simple_query("SET SERVER ROLE TO 'auto'"); - assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None); + assert!(qr + .try_execute_command(&simple_query("SET PRIMARY READS TO off")) + .is_some()); - assert!(qr.try_execute_command(&query) != None); + assert!(qr.try_execute_command(&query).is_some()); assert!(qr.query_parser_enabled()); assert_eq!(qr.role(), None); @@ -1149,7 +1128,7 @@ mod test { assert!(qr.query_parser_enabled()); let query = simple_query("SET SERVER ROLE TO 'default'"); - assert!(qr.try_execute_command(&query) != None); + assert!(qr.try_execute_command(&query).is_some()); assert!(!qr.query_parser_enabled()); } @@ -1194,11 +1173,11 @@ mod test { assert!(!qr.primary_reads_enabled()); let q1 = simple_query("SET SERVER ROLE TO 'primary'"); - assert!(qr.try_execute_command(&q1) != None); + assert!(qr.try_execute_command(&q1).is_some()); assert_eq!(qr.active_role.unwrap(), Role::Primary); let q2 = simple_query("SET SERVER ROLE TO 'default'"); - assert!(qr.try_execute_command(&q2) != None); + assert!(qr.try_execute_command(&q2).is_some()); assert_eq!(qr.active_role.unwrap(), pool_settings.default_role); } @@ -1263,17 +1242,17 @@ mod test { // Make sure setting it works let q1 = simple_query("/* shard_id: 1 */ select 1 from foo;"); - assert!(qr.try_execute_command(&q1) == None); + assert!(qr.try_execute_command(&q1).is_none()); assert_eq!(qr.active_shard, Some(1)); // And make sure changing it works let q2 = simple_query("/* shard_id: 0 */ select 1 from foo;"); - assert!(qr.try_execute_command(&q2) == None); + assert!(qr.try_execute_command(&q2).is_none()); assert_eq!(qr.active_shard, Some(0)); // Validate setting by shard with expected shard copied from sharding.rs tests let q2 = simple_query("/* sharding_key: 6 */ select 1 from foo;"); - assert!(qr.try_execute_command(&q2) == None); + assert!(qr.try_execute_command(&q2).is_none()); assert_eq!(qr.active_shard, Some(2)); } @@ -1411,9 +1390,11 @@ mod test { }; QueryRouter::setup(); - let mut pool_settings = PoolSettings::default(); - pool_settings.query_parser_enabled = true; - pool_settings.plugins = Some(plugins); + let pool_settings = PoolSettings { + query_parser_enabled: true, + plugins: Some(plugins), + ..Default::default() + }; let mut qr = QueryRouter::new(); qr.update_pool_settings(pool_settings); diff --git a/src/scram.rs b/src/scram.rs index 3e5d847..111dd5e 100644 --- a/src/scram.rs +++ b/src/scram.rs @@ -79,12 +79,12 @@ impl ScramSha256 { let server_message = Message::parse(message)?; if !server_message.nonce.starts_with(&self.nonce) { - return Err(Error::ProtocolSyncError(format!("SCRAM"))); + return Err(Error::ProtocolSyncError("SCRAM".to_string())); } let salt = match general_purpose::STANDARD.decode(&server_message.salt) { Ok(salt) => salt, - Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))), + Err(_) => return Err(Error::ProtocolSyncError("SCRAM".to_string())), }; let salted_password = Self::hi( @@ -166,9 +166,9 @@ impl ScramSha256 { pub fn finish(&mut self, message: &BytesMut) -> Result<(), Error> { let final_message = FinalMessage::parse(message)?; - let verifier = match general_purpose::STANDARD.decode(&final_message.value) { + let verifier = match general_purpose::STANDARD.decode(final_message.value) { Ok(verifier) => verifier, - Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))), + Err(_) => return Err(Error::ProtocolSyncError("SCRAM".to_string())), }; let mut hmac = match Hmac::::new_from_slice(&self.salted_password) { @@ -230,14 +230,14 @@ impl Message { .collect::>(); if parts.len() != 3 { - return Err(Error::ProtocolSyncError(format!("SCRAM"))); + return Err(Error::ProtocolSyncError("SCRAM".to_string())); } let nonce = str::replace(&parts[0], "r=", ""); let salt = str::replace(&parts[1], "s=", ""); let iterations = match str::replace(&parts[2], "i=", "").parse::() { Ok(iterations) => iterations, - Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))), + Err(_) => return Err(Error::ProtocolSyncError("SCRAM".to_string())), }; Ok(Message { @@ -257,7 +257,7 @@ impl FinalMessage { /// Parse the server final validation message. pub fn parse(message: &BytesMut) -> Result { if !message.starts_with(b"v=") || message.len() < 4 { - return Err(Error::ProtocolSyncError(format!("SCRAM"))); + return Err(Error::ProtocolSyncError("SCRAM".to_string())); } Ok(FinalMessage { diff --git a/src/server.rs b/src/server.rs index 55444fb..f18c487 100644 --- a/src/server.rs +++ b/src/server.rs @@ -316,10 +316,7 @@ impl Server { // Something else? m => { - return Err(Error::SocketError(format!( - "Unknown message: {}", - m as char - ))); + return Err(Error::SocketError(format!("Unknown message: {}", { m }))); } } } else { @@ -337,27 +334,18 @@ impl Server { None => &user.username, }; - let password = match user.server_password { - Some(ref server_password) => Some(server_password), - None => match user.password { - Some(ref password) => Some(password), - None => None, - }, - }; + let password = user.server_password.as_ref(); startup(&mut stream, username, database).await?; let mut server_info = BytesMut::new(); let mut process_id: i32 = 0; let mut secret_key: i32 = 0; - let server_identifier = ServerIdentifier::new(username, &database); + let server_identifier = ServerIdentifier::new(username, database); // We'll be handling multiple packets, but they will all be structured the same. // We'll loop here until this exchange is complete. - let mut scram: Option = match password { - Some(password) => Some(ScramSha256::new(password)), - None => None, - }; + let mut scram: Option = password.map(|password| ScramSha256::new(password)); loop { let code = match stream.read_u8().await { @@ -753,7 +741,7 @@ impl Server { self.mirror_send(messages); self.stats().data_sent(messages.len()); - match write_all_flush(&mut self.stream, &messages).await { + match write_all_flush(&mut self.stream, messages).await { Ok(_) => { // Successfully sent to server self.last_activity = SystemTime::now(); @@ -1199,16 +1187,14 @@ impl Server { } pub fn mirror_send(&mut self, bytes: &BytesMut) { - match self.mirror_manager.as_mut() { - Some(manager) => manager.send(bytes), - None => (), + if let Some(manager) = self.mirror_manager.as_mut() { + manager.send(bytes); } } pub fn mirror_disconnect(&mut self) { - match self.mirror_manager.as_mut() { - Some(manager) => manager.disconnect(), - None => (), + if let Some(manager) = self.mirror_manager.as_mut() { + manager.disconnect(); } } @@ -1236,7 +1222,7 @@ impl Server { server.send(&simple_query(query)).await?; let mut message = server.recv().await?; - Ok(parse_query_message(&mut message).await?) + parse_query_message(&mut message).await } } diff --git a/src/sharding.rs b/src/sharding.rs index 18581dc..7aa7f36 100644 --- a/src/sharding.rs +++ b/src/sharding.rs @@ -64,7 +64,7 @@ impl Sharder { fn sha1(&self, key: i64) -> usize { let mut hasher = Sha1::new(); - hasher.update(&key.to_string().as_bytes()); + hasher.update(key.to_string().as_bytes()); let result = hasher.finalize(); diff --git a/src/stats/pool.rs b/src/stats/pool.rs index d3ac78e..46c7463 100644 --- a/src/stats/pool.rs +++ b/src/stats/pool.rs @@ -86,11 +86,11 @@ impl PoolStats { } } - return map; + map } pub fn generate_header() -> Vec<(&'static str, DataType)> { - return vec![ + vec![ ("database", DataType::Text), ("user", DataType::Text), ("pool_mode", DataType::Text), @@ -105,11 +105,11 @@ impl PoolStats { ("sv_login", DataType::Numeric), ("maxwait", DataType::Numeric), ("maxwait_us", DataType::Numeric), - ]; + ] } pub fn generate_row(&self) -> Vec { - return vec![ + vec![ self.identifier.db.clone(), self.identifier.user.clone(), self.mode.to_string(), @@ -124,7 +124,7 @@ impl PoolStats { self.sv_login.to_string(), (self.maxwait / 1_000_000).to_string(), (self.maxwait % 1_000_000).to_string(), - ]; + ] } }