Added clippy to CI and fixed all clippy warnings (#613)

* Fixed all clippy warnings.

* Added `clippy` to CI.

* Reverted an unwanted change + Applied `cargo fmt`.

* Fixed the idiom version.

* Revert "Fixed the idiom version."

This reverts commit 6f78be0d42.

* Fixed clippy issues on CI.

* Revert "Fixed clippy issues on CI."

This reverts commit a9fa6ba189.

* Revert "Reverted an unwanted change + Applied `cargo fmt`."

This reverts commit 6bd37b6479.

* Revert "Fixed all clippy warnings."

This reverts commit d1f3b847e3.

* Removed Clippy

* Removed Lint

* `admin.rs` clippy fixes.

* Applied more clippy changes.

* Even more clippy changes.

* `client.rs` clippy fixes.

* `server.rs` clippy fixes.

* Revert "Removed Lint"

This reverts commit cb5042b144.

* Revert "Removed Clippy"

This reverts commit 6dec8bffb1.

* Applied lint.

* Revert "Revert "Fixed clippy issues on CI.""

This reverts commit 49164a733c.
This commit is contained in:
Mohammad Dashti
2023-10-10 09:18:21 -07:00
committed by GitHub
parent c4fb72b9fc
commit de8df29ca4
18 changed files with 258 additions and 304 deletions

View File

@@ -63,6 +63,9 @@ jobs:
- run: - run:
name: "Lint" name: "Lint"
command: "cargo fmt --check" command: "cargo fmt --check"
- run:
name: "Clippy"
command: "cargo clippy --all --all-targets -- -Dwarnings"
- run: - run:
name: "Tests" name: "Tests"
command: "cargo clean && cargo build && cargo test && bash .circleci/run_tests.sh && .circleci/generate_coverage.sh" command: "cargo clean && cargo build && cargo test && bash .circleci/run_tests.sh && .circleci/generate_coverage.sh"

View File

@@ -2,7 +2,7 @@
Thank you for contributing! Just a few tips here: Thank you for contributing! Just a few tips here:
1. `cargo fmt` your code before opening up a PR 1. `cargo fmt` and `cargo clippy` your code before opening up a PR
2. Run the test suite (e.g. `pgbench`) to make sure everything still works. The tests are in `.circleci/run_tests.sh`. 2. Run the test suite (e.g. `pgbench`) to make sure everything still works. The tests are in `.circleci/run_tests.sh`.
3. Performance is important, make sure there are no regressions in your branch vs. `main`. 3. Performance is important, make sure there are no regressions in your branch vs. `main`.

View File

@@ -283,7 +283,7 @@ where
{ {
let mut res = BytesMut::new(); let mut res = BytesMut::new();
let detail_msg = vec![ let detail_msg = [
"", "",
"SHOW HELP|CONFIG|DATABASES|POOLS|CLIENTS|SERVERS|USERS|VERSION", "SHOW HELP|CONFIG|DATABASES|POOLS|CLIENTS|SERVERS|USERS|VERSION",
// "SHOW PEERS|PEER_POOLS", // missing PEERS|PEER_POOLS // "SHOW PEERS|PEER_POOLS", // missing PEERS|PEER_POOLS
@@ -301,7 +301,6 @@ where
// "KILL <db>", // "KILL <db>",
// "SUSPEND", // "SUSPEND",
"SHUTDOWN", "SHUTDOWN",
// "WAIT_CLOSE [<db>]", // missing
]; ];
res.put(notify("Console usage", detail_msg.join("\n\t"))); res.put(notify("Console usage", detail_msg.join("\n\t")));
@@ -802,7 +801,7 @@ where
T: tokio::io::AsyncWrite + std::marker::Unpin, T: tokio::io::AsyncWrite + std::marker::Unpin,
{ {
let parts: Vec<&str> = match tokens.len() == 2 { let parts: Vec<&str> = match tokens.len() == 2 {
true => tokens[1].split(",").map(|part| part.trim()).collect(), true => tokens[1].split(',').map(|part| part.trim()).collect(),
false => Vec::new(), false => Vec::new(),
}; };
@@ -865,7 +864,7 @@ where
T: tokio::io::AsyncWrite + std::marker::Unpin, T: tokio::io::AsyncWrite + std::marker::Unpin,
{ {
let parts: Vec<&str> = match tokens.len() == 2 { let parts: Vec<&str> = match tokens.len() == 2 {
true => tokens[1].split(",").map(|part| part.trim()).collect(), true => tokens[1].split(',').map(|part| part.trim()).collect(),
false => Vec::new(), false => Vec::new(),
}; };

View File

@@ -131,7 +131,7 @@ pub async fn client_entrypoint(
// Client requested a TLS connection. // Client requested a TLS connection.
Ok((ClientConnectionType::Tls, _)) => { Ok((ClientConnectionType::Tls, _)) => {
// TLS settings are configured, will setup TLS now. // TLS settings are configured, will setup TLS now.
if tls_certificate != None { if tls_certificate.is_some() {
debug!("Accepting TLS request"); debug!("Accepting TLS request");
let mut yes = BytesMut::new(); let mut yes = BytesMut::new();
@@ -448,7 +448,7 @@ where
None => "pgcat", None => "pgcat",
}; };
let client_identifier = ClientIdentifier::new(&application_name, &username, &pool_name); let client_identifier = ClientIdentifier::new(application_name, username, pool_name);
let admin = ["pgcat", "pgbouncer"] let admin = ["pgcat", "pgbouncer"]
.iter() .iter()
@@ -795,7 +795,7 @@ where
let mut will_prepare = false; let mut will_prepare = false;
let client_identifier = ClientIdentifier::new( let client_identifier = ClientIdentifier::new(
&self.server_parameters.get_application_name(), self.server_parameters.get_application_name(),
&self.username, &self.username,
&self.pool_name, &self.pool_name,
); );
@@ -982,15 +982,11 @@ where
} }
// Check on plugin results. // Check on plugin results.
match plugin_output { if let Some(PluginOutput::Deny(error)) = plugin_output {
Some(PluginOutput::Deny(error)) => { self.buffer.clear();
self.buffer.clear(); error_response(&mut self.write, &error).await?;
error_response(&mut self.write, &error).await?; plugin_output = None;
plugin_output = None; continue;
continue;
}
_ => (),
}; };
// Check if the pool is paused and wait until it's resumed. // Check if the pool is paused and wait until it's resumed.
@@ -1267,7 +1263,7 @@ where
// Safe to unwrap because we know this message has a certain length and has the code // Safe to unwrap because we know this message has a certain length and has the code
// This reads the first byte without advancing the internal pointer and mutating the bytes // This reads the first byte without advancing the internal pointer and mutating the bytes
let code = *message.get(0).unwrap() as char; let code = *message.first().unwrap() as char;
trace!("Message: {}", code); trace!("Message: {}", code);
@@ -1325,7 +1321,7 @@ where
self.stats.transaction(); self.stats.transaction();
server server
.stats() .stats()
.transaction(&self.server_parameters.get_application_name()); .transaction(self.server_parameters.get_application_name());
// Release server back to the pool if we are in transaction mode. // Release server back to the pool if we are in transaction mode.
// If we are in session mode, we keep the server until the client disconnects. // If we are in session mode, we keep the server until the client disconnects.
@@ -1400,13 +1396,10 @@ where
let close: Close = (&message).try_into()?; let close: Close = (&message).try_into()?;
if close.is_prepared_statement() && !close.anonymous() { if close.is_prepared_statement() && !close.anonymous() {
match self.prepared_statements.get(&close.name) { if let Some(parse) = self.prepared_statements.get(&close.name) {
Some(parse) => { server.will_close(&parse.generated_name);
server.will_close(&parse.generated_name); } else {
}
// A prepared statement slipped through? Not impossible, since we don't support PREPARE yet. // A prepared statement slipped through? Not impossible, since we don't support PREPARE yet.
None => (),
}; };
} }
} }
@@ -1445,7 +1438,7 @@ where
self.buffer.put(&message[..]); self.buffer.put(&message[..]);
let first_message_code = (*self.buffer.get(0).unwrap_or(&0)) as char; let first_message_code = (*self.buffer.first().unwrap_or(&0)) as char;
// Almost certainly true // Almost certainly true
if first_message_code == 'P' && !prepared_statements_enabled { if first_message_code == 'P' && !prepared_statements_enabled {
@@ -1477,7 +1470,7 @@ where
self.stats.transaction(); self.stats.transaction();
server server
.stats() .stats()
.transaction(&self.server_parameters.get_application_name()); .transaction(self.server_parameters.get_application_name());
// Release server back to the pool if we are in transaction mode. // Release server back to the pool if we are in transaction mode.
// If we are in session mode, we keep the server until the client disconnects. // If we are in session mode, we keep the server until the client disconnects.
@@ -1739,7 +1732,7 @@ where
client_stats.query(); client_stats.query();
server.stats().query( server.stats().query(
Instant::now().duration_since(query_start).as_millis() as u64, Instant::now().duration_since(query_start).as_millis() as u64,
&self.server_parameters.get_application_name(), self.server_parameters.get_application_name(),
); );
Ok(()) Ok(())

View File

@@ -25,7 +25,7 @@ pub struct Args {
} }
pub fn parse() -> Args { pub fn parse() -> Args {
return Args::parse(); Args::parse()
} }
#[derive(ValueEnum, Clone, Debug)] #[derive(ValueEnum, Clone, Debug)]

View File

@@ -236,18 +236,14 @@ impl Default for User {
impl User { impl User {
fn validate(&self) -> Result<(), Error> { fn validate(&self) -> Result<(), Error> {
match self.min_pool_size { if let Some(min_pool_size) = self.min_pool_size {
Some(min_pool_size) => { if min_pool_size > self.pool_size {
if min_pool_size > self.pool_size { error!(
error!( "min_pool_size of {} cannot be larger than pool_size of {}",
"min_pool_size of {} cannot be larger than pool_size of {}", min_pool_size, self.pool_size
min_pool_size, self.pool_size );
); return Err(Error::BadConfig);
return Err(Error::BadConfig);
}
} }
None => (),
}; };
Ok(()) Ok(())
@@ -677,9 +673,9 @@ impl Pool {
Some(key) => { Some(key) => {
// No quotes in the key so we don't have to compare quoted // No quotes in the key so we don't have to compare quoted
// to unquoted idents. // to unquoted idents.
let key = key.replace("\"", ""); let key = key.replace('\"', "");
if key.split(".").count() != 2 { if key.split('.').count() != 2 {
error!( error!(
"automatic_sharding_key '{}' must be fully qualified, e.g. t.{}`", "automatic_sharding_key '{}' must be fully qualified, e.g. t.{}`",
key, key key, key
@@ -692,17 +688,14 @@ impl Pool {
None => None, None => None,
}; };
match self.default_shard { if let DefaultShard::Shard(shard_number) = self.default_shard {
DefaultShard::Shard(shard_number) => { if shard_number >= self.shards.len() {
if shard_number >= self.shards.len() { error!("Invalid shard {:?}", shard_number);
error!("Invalid shard {:?}", shard_number); return Err(Error::BadConfig);
return Err(Error::BadConfig);
}
} }
_ => (),
} }
for (_, user) in &self.users { for user in self.users.values() {
user.validate()?; user.validate()?;
} }
@@ -777,8 +770,8 @@ impl<'de> serde::Deserialize<'de> for DefaultShard {
D: Deserializer<'de>, D: Deserializer<'de>,
{ {
let s = String::deserialize(deserializer)?; let s = String::deserialize(deserializer)?;
if s.starts_with("shard_") { if let Some(s) = s.strip_prefix("shard_") {
let shard = s[6..].parse::<usize>().map_err(serde::de::Error::custom)?; let shard = s.parse::<usize>().map_err(serde::de::Error::custom)?;
return Ok(DefaultShard::Shard(shard)); return Ok(DefaultShard::Shard(shard));
} }
@@ -874,7 +867,7 @@ pub trait Plugin {
impl std::fmt::Display for Plugins { impl std::fmt::Display for Plugins {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
fn is_enabled<T: Plugin>(arg: Option<&T>) -> bool { fn is_enabled<T: Plugin>(arg: Option<&T>) -> bool {
if let Some(ref arg) = arg { if let Some(arg) = arg {
arg.is_enabled() arg.is_enabled()
} else { } else {
false false
@@ -955,6 +948,7 @@ pub struct Query {
} }
impl Query { impl Query {
#[allow(clippy::needless_range_loop)]
pub fn substitute(&mut self, db: &str, user: &str) { pub fn substitute(&mut self, db: &str, user: &str) {
for col in self.result.iter_mut() { for col in self.result.iter_mut() {
for i in 0..col.len() { for i in 0..col.len() {
@@ -1079,8 +1073,8 @@ impl From<&Config> for std::collections::HashMap<String, String> {
( (
format!("pools.{:?}.users", pool_name), format!("pools.{:?}.users", pool_name),
pool.users pool.users
.iter() .values()
.map(|(_username, user)| &user.username) .map(|user| &user.username)
.cloned() .cloned()
.collect::<Vec<String>>() .collect::<Vec<String>>()
.join(", "), .join(", "),
@@ -1165,13 +1159,9 @@ impl Config {
Some(tls_certificate) => { Some(tls_certificate) => {
info!("TLS certificate: {}", tls_certificate); info!("TLS certificate: {}", tls_certificate);
match self.general.tls_private_key.clone() { if let Some(tls_private_key) = self.general.tls_private_key.clone() {
Some(tls_private_key) => { info!("TLS private key: {}", tls_private_key);
info!("TLS private key: {}", tls_private_key); info!("TLS support is enabled");
info!("TLS support is enabled");
}
None => (),
} }
} }
@@ -1206,8 +1196,8 @@ impl Config {
pool_name, pool_name,
pool_config pool_config
.users .users
.iter() .values()
.map(|(_, user_cfg)| user_cfg.pool_size) .map(|user_cfg| user_cfg.pool_size)
.sum::<u32>() .sum::<u32>()
.to_string() .to_string()
); );
@@ -1377,34 +1367,31 @@ impl Config {
} }
// Validate TLS! // Validate TLS!
match self.general.tls_certificate.clone() { if let Some(tls_certificate) = self.general.tls_certificate.clone() {
Some(tls_certificate) => { match load_certs(Path::new(&tls_certificate)) {
match load_certs(Path::new(&tls_certificate)) { Ok(_) => {
Ok(_) => { // Cert is okay, but what about the private key?
// Cert is okay, but what about the private key? match self.general.tls_private_key.clone() {
match self.general.tls_private_key.clone() { Some(tls_private_key) => match load_keys(Path::new(&tls_private_key)) {
Some(tls_private_key) => match load_keys(Path::new(&tls_private_key)) { Ok(_) => (),
Ok(_) => (), Err(err) => {
Err(err) => { error!("tls_private_key is incorrectly configured: {:?}", err);
error!("tls_private_key is incorrectly configured: {:?}", err);
return Err(Error::BadConfig);
}
},
None => {
error!("tls_certificate is set, but the tls_private_key is not");
return Err(Error::BadConfig); return Err(Error::BadConfig);
} }
}; },
}
Err(err) => { None => {
error!("tls_certificate is incorrectly configured: {:?}", err); error!("tls_certificate is set, but the tls_private_key is not");
return Err(Error::BadConfig); return Err(Error::BadConfig);
} }
};
}
Err(err) => {
error!("tls_certificate is incorrectly configured: {:?}", err);
return Err(Error::BadConfig);
} }
} }
None => (),
}; };
for pool in self.pools.values_mut() { for pool in self.pools.values_mut() {

View File

@@ -163,12 +163,10 @@ where
match stream.write_all(&startup).await { match stream.write_all(&startup).await {
Ok(_) => Ok(()), Ok(_) => Ok(()),
Err(err) => { Err(err) => Err(Error::SocketError(format!(
return Err(Error::SocketError(format!( "Error writing startup to server socket - Error: {:?}",
"Error writing startup to server socket - Error: {:?}", err
err ))),
)))
}
} }
} }
@@ -244,8 +242,8 @@ pub fn md5_hash_password(user: &str, password: &str, salt: &[u8]) -> Vec<u8> {
let mut md5 = Md5::new(); let mut md5 = Md5::new();
// First pass // First pass
md5.update(&password.as_bytes()); md5.update(password.as_bytes());
md5.update(&user.as_bytes()); md5.update(user.as_bytes());
let output = md5.finalize_reset(); let output = md5.finalize_reset();
@@ -281,7 +279,7 @@ where
{ {
let password = md5_hash_password(user, password, salt); let password = md5_hash_password(user, password, salt);
let mut message = BytesMut::with_capacity(password.len() as usize + 5); let mut message = BytesMut::with_capacity(password.len() + 5);
message.put_u8(b'p'); message.put_u8(b'p');
message.put_i32(password.len() as i32 + 4); message.put_i32(password.len() as i32 + 4);
@@ -295,7 +293,7 @@ where
S: tokio::io::AsyncWrite + std::marker::Unpin, S: tokio::io::AsyncWrite + std::marker::Unpin,
{ {
let password = md5_hash_second_pass(hash, salt); let password = md5_hash_second_pass(hash, salt);
let mut message = BytesMut::with_capacity(password.len() as usize + 5); let mut message = BytesMut::with_capacity(password.len() + 5);
message.put_u8(b'p'); message.put_u8(b'p');
message.put_i32(password.len() as i32 + 4); message.put_i32(password.len() as i32 + 4);
@@ -516,7 +514,7 @@ pub fn data_row_nullable(row: &Vec<Option<String>>) -> BytesMut {
data_row.put_i32(column.len() as i32); data_row.put_i32(column.len() as i32);
data_row.put_slice(column); data_row.put_slice(column);
} else { } else {
data_row.put_i32(-1 as i32); data_row.put_i32(-1_i32);
} }
} }
@@ -571,12 +569,10 @@ where
{ {
match stream.write_all(&buf).await { match stream.write_all(&buf).await {
Ok(_) => Ok(()), Ok(_) => Ok(()),
Err(err) => { Err(err) => Err(Error::SocketError(format!(
return Err(Error::SocketError(format!( "Error writing to socket - Error: {:?}",
"Error writing to socket - Error: {:?}", err
err ))),
)))
}
} }
} }
@@ -587,12 +583,10 @@ where
{ {
match stream.write_all(buf).await { match stream.write_all(buf).await {
Ok(_) => Ok(()), Ok(_) => Ok(()),
Err(err) => { Err(err) => Err(Error::SocketError(format!(
return Err(Error::SocketError(format!( "Error writing to socket - Error: {:?}",
"Error writing to socket - Error: {:?}", err
err ))),
)))
}
} }
} }
@@ -603,19 +597,15 @@ where
match stream.write_all(buf).await { match stream.write_all(buf).await {
Ok(_) => match stream.flush().await { Ok(_) => match stream.flush().await {
Ok(_) => Ok(()), Ok(_) => Ok(()),
Err(err) => { Err(err) => Err(Error::SocketError(format!(
return Err(Error::SocketError(format!( "Error flushing socket - Error: {:?}",
"Error flushing socket - Error: {:?}",
err
)))
}
},
Err(err) => {
return Err(Error::SocketError(format!(
"Error writing to socket - Error: {:?}",
err err
))) ))),
} },
Err(err) => Err(Error::SocketError(format!(
"Error writing to socket - Error: {:?}",
err
))),
} }
} }
@@ -730,7 +720,7 @@ impl BytesMutReader for Cursor<&BytesMut> {
let mut buf = vec![]; let mut buf = vec![];
match self.read_until(b'\0', &mut buf) { match self.read_until(b'\0', &mut buf) {
Ok(_) => Ok(String::from_utf8_lossy(&buf[..buf.len() - 1]).to_string()), Ok(_) => Ok(String::from_utf8_lossy(&buf[..buf.len() - 1]).to_string()),
Err(err) => return Err(Error::ParseBytesError(err.to_string())), Err(err) => Err(Error::ParseBytesError(err.to_string())),
} }
} }
} }
@@ -746,7 +736,7 @@ impl BytesMutReader for BytesMut {
let string_bytes = self.split_to(index + 1); let string_bytes = self.split_to(index + 1);
Ok(String::from_utf8_lossy(&string_bytes[..string_bytes.len() - 1]).to_string()) Ok(String::from_utf8_lossy(&string_bytes[..string_bytes.len() - 1]).to_string())
} }
None => return Err(Error::ParseBytesError("Could not read string".to_string())), None => Err(Error::ParseBytesError("Could not read string".to_string())),
} }
} }
} }
@@ -1311,38 +1301,38 @@ mod tests {
fn parse_fields() { fn parse_fields() {
let mut complete_msg = vec![]; let mut complete_msg = vec![];
let severity = "FATAL"; let severity = "FATAL";
complete_msg.extend(field('S', &severity)); complete_msg.extend(field('S', severity));
complete_msg.extend(field('V', &severity)); complete_msg.extend(field('V', severity));
let error_code = "29P02"; let error_code = "29P02";
complete_msg.extend(field('C', &error_code)); complete_msg.extend(field('C', error_code));
let message = "password authentication failed for user \"wrong_user\""; let message = "password authentication failed for user \"wrong_user\"";
complete_msg.extend(field('M', &message)); complete_msg.extend(field('M', message));
let detail_msg = "super detailed message"; let detail_msg = "super detailed message";
complete_msg.extend(field('D', &detail_msg)); complete_msg.extend(field('D', detail_msg));
let hint_msg = "hint detail here"; let hint_msg = "hint detail here";
complete_msg.extend(field('H', &hint_msg)); complete_msg.extend(field('H', hint_msg));
complete_msg.extend(field('P', "123")); complete_msg.extend(field('P', "123"));
complete_msg.extend(field('p', "234")); complete_msg.extend(field('p', "234"));
let internal_query = "SELECT * from foo;"; let internal_query = "SELECT * from foo;";
complete_msg.extend(field('q', &internal_query)); complete_msg.extend(field('q', internal_query));
let where_msg = "where goes here"; let where_msg = "where goes here";
complete_msg.extend(field('W', &where_msg)); complete_msg.extend(field('W', where_msg));
let schema_msg = "schema_name"; let schema_msg = "schema_name";
complete_msg.extend(field('s', &schema_msg)); complete_msg.extend(field('s', schema_msg));
let table_msg = "table_name"; let table_msg = "table_name";
complete_msg.extend(field('t', &table_msg)); complete_msg.extend(field('t', table_msg));
let column_msg = "column_name"; let column_msg = "column_name";
complete_msg.extend(field('c', &column_msg)); complete_msg.extend(field('c', column_msg));
let data_type_msg = "type_name"; let data_type_msg = "type_name";
complete_msg.extend(field('d', &data_type_msg)); complete_msg.extend(field('d', data_type_msg));
let constraint_msg = "constraint_name"; let constraint_msg = "constraint_name";
complete_msg.extend(field('n', &constraint_msg)); complete_msg.extend(field('n', constraint_msg));
let file_msg = "pgcat.c"; let file_msg = "pgcat.c";
complete_msg.extend(field('F', &file_msg)); complete_msg.extend(field('F', file_msg));
complete_msg.extend(field('L', "335")); complete_msg.extend(field('L', "335"));
let routine_msg = "my_failing_routine"; let routine_msg = "my_failing_routine";
complete_msg.extend(field('R', &routine_msg)); complete_msg.extend(field('R', routine_msg));
tracing_subscriber::fmt() tracing_subscriber::fmt()
.with_max_level(tracing::Level::INFO) .with_max_level(tracing::Level::INFO)
@@ -1378,11 +1368,11 @@ mod tests {
); );
let mut only_mandatory_msg = vec![]; let mut only_mandatory_msg = vec![];
only_mandatory_msg.extend(field('S', &severity)); only_mandatory_msg.extend(field('S', severity));
only_mandatory_msg.extend(field('V', &severity)); only_mandatory_msg.extend(field('V', severity));
only_mandatory_msg.extend(field('C', &error_code)); only_mandatory_msg.extend(field('C', error_code));
only_mandatory_msg.extend(field('M', &message)); only_mandatory_msg.extend(field('M', message));
only_mandatory_msg.extend(field('D', &detail_msg)); only_mandatory_msg.extend(field('D', detail_msg));
let err_fields = PgErrorMsg::parse(only_mandatory_msg.clone()).unwrap(); let err_fields = PgErrorMsg::parse(only_mandatory_msg.clone()).unwrap();
info!("only mandatory fields: {}", &err_fields); info!("only mandatory fields: {}", &err_fields);

View File

@@ -137,18 +137,18 @@ impl MirroringManager {
bytes_rx, bytes_rx,
disconnect_rx: exit_rx, disconnect_rx: exit_rx,
}; };
exit_senders.push(exit_tx.clone()); exit_senders.push(exit_tx);
byte_senders.push(bytes_tx.clone()); byte_senders.push(bytes_tx);
client.start(); client.start();
}); });
Self { Self {
byte_senders: byte_senders, byte_senders,
disconnect_senders: exit_senders, disconnect_senders: exit_senders,
} }
} }
pub fn send(self: &mut Self, bytes: &BytesMut) { pub fn send(&mut self, bytes: &BytesMut) {
// We want to avoid performing an allocation if we won't be able to send the message // We want to avoid performing an allocation if we won't be able to send the message
// There is a possibility of a race here where we check the capacity and then the channel is // There is a possibility of a race here where we check the capacity and then the channel is
// closed or the capacity is reduced to 0, but mirroring is best effort anyway // closed or the capacity is reduced to 0, but mirroring is best effort anyway
@@ -170,7 +170,7 @@ impl MirroringManager {
}); });
} }
pub fn disconnect(self: &mut Self) { pub fn disconnect(&mut self) {
self.disconnect_senders self.disconnect_senders
.iter_mut() .iter_mut()
.for_each(|sender| match sender.try_send(()) { .for_each(|sender| match sender.try_send(()) {

View File

@@ -92,7 +92,7 @@ impl<'a> Plugin for Intercept<'a> {
.map(|s| { .map(|s| {
let s = s.as_str().to_string(); let s = s.as_str().to_string();
if s == "" { if s.is_empty() {
None None
} else { } else {
Some(s) Some(s)

View File

@@ -33,6 +33,7 @@ pub enum PluginOutput {
#[async_trait] #[async_trait]
pub trait Plugin { pub trait Plugin {
// Run before the query is sent to the server. // Run before the query is sent to the server.
#[allow(clippy::ptr_arg)]
async fn run( async fn run(
&mut self, &mut self,
query_router: &QueryRouter, query_router: &QueryRouter,

View File

@@ -20,7 +20,7 @@ impl<'a> Prewarmer<'a> {
self.server.address(), self.server.address(),
query query
); );
self.server.query(&query).await?; self.server.query(query).await?;
} }
Ok(()) Ok(())

View File

@@ -34,7 +34,7 @@ impl<'a> Plugin for TableAccess<'a> {
visit_relations(ast, |relation| { visit_relations(ast, |relation| {
let relation = relation.to_string(); let relation = relation.to_string();
let parts = relation.split(".").collect::<Vec<&str>>(); let parts = relation.split('.').collect::<Vec<&str>>();
let table_name = parts.last().unwrap(); let table_name = parts.last().unwrap();
if self.tables.contains(&table_name.to_string()) { if self.tables.contains(&table_name.to_string()) {

View File

@@ -241,20 +241,17 @@ impl ConnectionPool {
let old_pool_ref = get_pool(pool_name, &user.username); let old_pool_ref = get_pool(pool_name, &user.username);
let identifier = PoolIdentifier::new(pool_name, &user.username); let identifier = PoolIdentifier::new(pool_name, &user.username);
match old_pool_ref { if let Some(pool) = old_pool_ref {
Some(pool) => { // If the pool hasn't changed, get existing reference and insert it into the new_pools.
// If the pool hasn't changed, get existing reference and insert it into the new_pools. // We replace all pools at the end, but if the reference is kept, the pool won't get re-created (bb8).
// We replace all pools at the end, but if the reference is kept, the pool won't get re-created (bb8). if pool.config_hash == new_pool_hash_value {
if pool.config_hash == new_pool_hash_value { info!(
info!( "[pool: {}][user: {}] has not changed",
"[pool: {}][user: {}] has not changed", pool_name, user.username
pool_name, user.username );
); new_pools.insert(identifier.clone(), pool.clone());
new_pools.insert(identifier.clone(), pool.clone()); continue;
continue;
}
} }
None => (),
} }
info!( info!(
@@ -399,7 +396,7 @@ impl ConnectionPool {
}, },
}; };
let reaper_rate = *vec![idle_timeout, server_lifetime, POOL_REAPER_RATE] let reaper_rate = *[idle_timeout, server_lifetime, POOL_REAPER_RATE]
.iter() .iter()
.min() .min()
.unwrap(); .unwrap();
@@ -489,7 +486,7 @@ impl ConnectionPool {
.clone() .clone()
.map(|regex| Regex::new(regex.as_str()).unwrap()), .map(|regex| Regex::new(regex.as_str()).unwrap()),
regex_search_limit: pool_config.regex_search_limit.unwrap_or(1000), regex_search_limit: pool_config.regex_search_limit.unwrap_or(1000),
default_shard: pool_config.default_shard.clone(), default_shard: pool_config.default_shard,
auth_query: pool_config.auth_query.clone(), auth_query: pool_config.auth_query.clone(),
auth_query_user: pool_config.auth_query_user.clone(), auth_query_user: pool_config.auth_query_user.clone(),
auth_query_password: pool_config.auth_query_password.clone(), auth_query_password: pool_config.auth_query_password.clone(),
@@ -678,7 +675,7 @@ impl ConnectionPool {
let mut force_healthcheck = false; let mut force_healthcheck = false;
if self.is_banned(address) { if self.is_banned(address) {
if self.try_unban(&address).await { if self.try_unban(address).await {
force_healthcheck = true; force_healthcheck = true;
} else { } else {
debug!("Address {:?} is banned", address); debug!("Address {:?} is banned", address);
@@ -806,8 +803,8 @@ impl ConnectionPool {
// Don't leave a bad connection in the pool. // Don't leave a bad connection in the pool.
server.mark_bad(); server.mark_bad();
self.ban(&address, BanReason::FailedHealthCheck, Some(client_info)); self.ban(address, BanReason::FailedHealthCheck, Some(client_info));
return false; false
} }
/// Ban an address (i.e. replica). It no longer will serve /// Ban an address (i.e. replica). It no longer will serve
@@ -931,10 +928,10 @@ impl ConnectionPool {
let guard = self.banlist.read(); let guard = self.banlist.read();
for banlist in guard.iter() { for banlist in guard.iter() {
for (address, (reason, timestamp)) in banlist.iter() { for (address, (reason, timestamp)) in banlist.iter() {
bans.push((address.clone(), (reason.clone(), timestamp.clone()))); bans.push((address.clone(), (reason.clone(), *timestamp)));
} }
} }
return bans; bans
} }
/// Get the address from the host url /// Get the address from the host url
@@ -992,7 +989,7 @@ impl ConnectionPool {
} }
let busy = provisioned - idle; let busy = provisioned - idle;
debug!("{:?} has {:?} busy connections", address, busy); debug!("{:?} has {:?} busy connections", address, busy);
return busy; busy
} }
fn valid_shard_id(&self, shard: Option<usize>) -> bool { fn valid_shard_id(&self, shard: Option<usize>) -> bool {
@@ -1031,6 +1028,7 @@ pub struct ServerPool {
} }
impl ServerPool { impl ServerPool {
#[allow(clippy::too_many_arguments)]
pub fn new( pub fn new(
address: Address, address: Address,
user: User, user: User,
@@ -1043,7 +1041,7 @@ impl ServerPool {
) -> ServerPool { ) -> ServerPool {
ServerPool { ServerPool {
address, address,
user: user.clone(), user,
database: database.to_string(), database: database.to_string(),
client_server_map, client_server_map,
auth_hash, auth_hash,

View File

@@ -91,7 +91,7 @@ impl QueryRouter {
/// One-time initialization of regexes /// One-time initialization of regexes
/// that parse our custom SQL protocol. /// that parse our custom SQL protocol.
pub fn setup() -> bool { pub fn setup() -> bool {
let set = match RegexSet::new(&CUSTOM_SQL_REGEXES) { let set = match RegexSet::new(CUSTOM_SQL_REGEXES) {
Ok(rgx) => rgx, Ok(rgx) => rgx,
Err(err) => { Err(err) => {
error!("QueryRouter::setup Could not compile regex set: {:?}", err); error!("QueryRouter::setup Could not compile regex set: {:?}", err);
@@ -132,7 +132,7 @@ impl QueryRouter {
self.pool_settings = pool_settings; self.pool_settings = pool_settings;
} }
pub fn pool_settings<'a>(&'a self) -> &'a PoolSettings { pub fn pool_settings(&self) -> &PoolSettings {
&self.pool_settings &self.pool_settings
} }
@@ -148,7 +148,7 @@ impl QueryRouter {
// Check for any sharding regex matches in any queries // Check for any sharding regex matches in any queries
if comment_shard_routing_enabled { if comment_shard_routing_enabled {
match code as char { match code {
// For Parse and Query messages peek to see if they specify a shard_id as a comment early in the statement // For Parse and Query messages peek to see if they specify a shard_id as a comment early in the statement
'P' | 'Q' => { 'P' | 'Q' => {
// Check only the first block of bytes configured by the pool settings // Check only the first block of bytes configured by the pool settings
@@ -344,16 +344,13 @@ impl QueryRouter {
let code = message_cursor.get_u8() as char; let code = message_cursor.get_u8() as char;
let len = message_cursor.get_i32() as usize; let len = message_cursor.get_i32() as usize;
match self.pool_settings.query_parser_max_length { if let Some(max_length) = self.pool_settings.query_parser_max_length {
Some(max_length) => { if len > max_length {
if len > max_length { return Err(Error::QueryRouterParserError(format!(
return Err(Error::QueryRouterParserError(format!( "Query too long for parser: {} > {}",
"Query too long for parser: {} > {}", len, max_length
len, max_length )));
)));
}
} }
None => (),
}; };
let query = match code { let query = match code {
@@ -467,22 +464,18 @@ impl QueryRouter {
inferred_shard: Option<usize>, inferred_shard: Option<usize>,
prev_inferred_shard: &mut Option<usize>, prev_inferred_shard: &mut Option<usize>,
) -> Result<(), Error> { ) -> Result<(), Error> {
match inferred_shard { if let Some(shard) = inferred_shard {
Some(shard) => { if let Some(prev_shard) = *prev_inferred_shard {
if let Some(prev_shard) = *prev_inferred_shard { if prev_shard != shard {
if prev_shard != shard { debug!("Found more than one shard in the query, not supported yet");
debug!("Found more than one shard in the query, not supported yet"); return Err(Error::QueryRouterParserError(
return Err(Error::QueryRouterParserError( "multiple shards in query".into(),
"multiple shards in query".into(), ));
));
}
} }
*prev_inferred_shard = Some(shard);
self.active_shard = Some(shard);
debug!("Automatically using shard: {:?}", self.active_shard);
} }
*prev_inferred_shard = Some(shard);
None => (), self.active_shard = Some(shard);
debug!("Automatically using shard: {:?}", self.active_shard);
}; };
Ok(()) Ok(())
} }
@@ -513,7 +506,7 @@ impl QueryRouter {
assert!(after_columns.is_empty()); assert!(after_columns.is_empty());
Self::process_table(table_name, &mut table_names); Self::process_table(table_name, &mut table_names);
Self::process_query(&*source, &mut exprs, &mut table_names, &Some(columns)); Self::process_query(source, &mut exprs, &mut table_names, &Some(columns));
} }
Delete { Delete {
tables, tables,
@@ -529,7 +522,7 @@ impl QueryRouter {
// Multi tables delete are not supported in postgres. // Multi tables delete are not supported in postgres.
assert!(tables.is_empty()); assert!(tables.is_empty());
Self::process_tables_with_join(&from, &mut exprs, &mut table_names); Self::process_tables_with_join(from, &mut exprs, &mut table_names);
if let Some(using_tbl_with_join) = using { if let Some(using_tbl_with_join) = using {
Self::process_tables_with_join( Self::process_tables_with_join(
using_tbl_with_join, using_tbl_with_join,
@@ -569,7 +562,7 @@ impl QueryRouter {
) { ) {
match &*query.body { match &*query.body {
SetExpr::Query(query) => { SetExpr::Query(query) => {
Self::process_query(&*query, exprs, table_names, columns); Self::process_query(query, exprs, table_names, columns);
} }
// SELECT * FROM ... // SELECT * FROM ...
@@ -611,7 +604,7 @@ impl QueryRouter {
} }
fn process_tables_with_join( fn process_tables_with_join(
tables: &Vec<TableWithJoins>, tables: &[TableWithJoins],
exprs: &mut Vec<Expr>, exprs: &mut Vec<Expr>,
table_names: &mut Vec<Vec<Ident>>, table_names: &mut Vec<Vec<Ident>>,
) { ) {
@@ -625,37 +618,21 @@ impl QueryRouter {
exprs: &mut Vec<Expr>, exprs: &mut Vec<Expr>,
table_names: &mut Vec<Vec<Ident>>, table_names: &mut Vec<Vec<Ident>>,
) { ) {
match &table.relation { if let TableFactor::Table { name, .. } = &table.relation {
TableFactor::Table { name, .. } => { Self::process_table(name, table_names);
Self::process_table(name, table_names);
}
_ => (),
}; };
// Get table names from all the joins. // Get table names from all the joins.
for join in table.joins.iter() { for join in table.joins.iter() {
match &join.relation { if let TableFactor::Table { name, .. } = &join.relation {
TableFactor::Table { name, .. } => { Self::process_table(name, table_names);
Self::process_table(name, table_names);
}
_ => (),
}; };
// We can filter results based on join conditions, e.g. // We can filter results based on join conditions, e.g.
// SELECT * FROM t INNER JOIN B ON B.sharding_key = 5; // SELECT * FROM t INNER JOIN B ON B.sharding_key = 5;
match &join.join_operator { if let JoinOperator::Inner(JoinConstraint::On(expr)) = &join.join_operator {
JoinOperator::Inner(inner_join) => match &inner_join { // Parse the selection criteria later.
JoinConstraint::On(expr) => { exprs.push(expr.clone());
// Parse the selection criteria later.
exprs.push(expr.clone());
}
_ => (),
},
_ => (),
}; };
} }
} }
@@ -814,7 +791,7 @@ impl QueryRouter {
.automatic_sharding_key .automatic_sharding_key
.as_ref() .as_ref()
.unwrap() .unwrap()
.split(".") .split('.')
.map(|ident| Ident::new(ident.to_lowercase())) .map(|ident| Ident::new(ident.to_lowercase()))
.collect::<Vec<Ident>>(); .collect::<Vec<Ident>>();
@@ -822,12 +799,12 @@ impl QueryRouter {
assert_eq!(sharding_key.len(), 2); assert_eq!(sharding_key.len(), 2);
for a in assignments { for a in assignments {
if sharding_key[0].value == "*" { if sharding_key[0].value == "*"
if sharding_key[1].value == a.id.last().unwrap().value.to_lowercase() { && sharding_key[1].value == a.id.last().unwrap().value.to_lowercase()
return Err(Error::QueryRouterParserError( {
"Sharding key cannot be updated.".into(), return Err(Error::QueryRouterParserError(
)); "Sharding key cannot be updated.".into(),
} ));
} }
} }
Ok(()) Ok(())
@@ -844,7 +821,7 @@ impl QueryRouter {
.automatic_sharding_key .automatic_sharding_key
.as_ref() .as_ref()
.unwrap() .unwrap()
.split(".") .split('.')
.map(|ident| Ident::new(ident.to_lowercase())) .map(|ident| Ident::new(ident.to_lowercase()))
.collect::<Vec<Ident>>(); .collect::<Vec<Ident>>();
@@ -861,7 +838,7 @@ impl QueryRouter {
Expr::Identifier(ident) => { Expr::Identifier(ident) => {
// Only if we're dealing with only one table // Only if we're dealing with only one table
// and there is no ambiguity // and there is no ambiguity
if &ident.value.to_lowercase() == &sharding_key[1].value { if ident.value.to_lowercase() == sharding_key[1].value {
// Sharding key is unique enough, don't worry about // Sharding key is unique enough, don't worry about
// table names. // table names.
if &sharding_key[0].value == "*" { if &sharding_key[0].value == "*" {
@@ -874,13 +851,13 @@ impl QueryRouter {
// SELECT * FROM t WHERE sharding_key = 5 // SELECT * FROM t WHERE sharding_key = 5
// Make sure the table name from the sharding key matches // Make sure the table name from the sharding key matches
// the table name from the query. // the table name from the query.
found = &sharding_key[0].value == &table[0].value.to_lowercase(); found = sharding_key[0].value == table[0].value.to_lowercase();
} else if table.len() == 2 { } else if table.len() == 2 {
// Table name is fully qualified with the schema: e.g. // Table name is fully qualified with the schema: e.g.
// SELECT * FROM public.t WHERE sharding_key = 5 // SELECT * FROM public.t WHERE sharding_key = 5
// Ignore the schema (TODO: at some point, we want schema support) // Ignore the schema (TODO: at some point, we want schema support)
// and use the table name only. // and use the table name only.
found = &sharding_key[0].value == &table[1].value.to_lowercase(); found = sharding_key[0].value == table[1].value.to_lowercase();
} else { } else {
debug!("Got table name with more than two idents, which is not possible"); debug!("Got table name with more than two idents, which is not possible");
} }
@@ -893,8 +870,8 @@ impl QueryRouter {
// it will exist or Postgres will throw an error. // it will exist or Postgres will throw an error.
if idents.len() == 2 { if idents.len() == 2 {
found = (&sharding_key[0].value == "*" found = (&sharding_key[0].value == "*"
|| &sharding_key[0].value == &idents[0].value.to_lowercase()) || sharding_key[0].value == idents[0].value.to_lowercase())
&& &sharding_key[1].value == &idents[1].value.to_lowercase(); && sharding_key[1].value == idents[1].value.to_lowercase();
} }
// TODO: key can have schema as well, e.g. public.data.id (len == 3) // TODO: key can have schema as well, e.g. public.data.id (len == 3)
} }
@@ -926,7 +903,7 @@ impl QueryRouter {
} }
Expr::Value(Value::Placeholder(placeholder)) => { Expr::Value(Value::Placeholder(placeholder)) => {
match placeholder.replace("$", "").parse::<i16>() { match placeholder.replace('$', "").parse::<i16>() {
Ok(placeholder) => result.push(ShardingKey::Placeholder(placeholder)), Ok(placeholder) => result.push(ShardingKey::Placeholder(placeholder)),
Err(_) => { Err(_) => {
debug!( debug!(
@@ -1020,16 +997,16 @@ impl QueryRouter {
db: &self.pool_settings.db, db: &self.pool_settings.db,
}; };
let _ = query_logger.run(&self, ast).await; let _ = query_logger.run(self, ast).await;
} }
if let Some(ref intercept) = plugins.intercept { if let Some(ref intercept) = plugins.intercept {
let mut intercept = Intercept { let mut intercept = Intercept {
enabled: intercept.enabled, enabled: intercept.enabled,
config: &intercept, config: intercept,
}; };
let result = intercept.run(&self, ast).await; let result = intercept.run(self, ast).await;
if let Ok(PluginOutput::Intercept(output)) = result { if let Ok(PluginOutput::Intercept(output)) = result {
return Ok(PluginOutput::Intercept(output)); return Ok(PluginOutput::Intercept(output));
@@ -1042,7 +1019,7 @@ impl QueryRouter {
tables: &table_access.tables, tables: &table_access.tables,
}; };
let result = table_access.run(&self, ast).await; let result = table_access.run(self, ast).await;
if let Ok(PluginOutput::Deny(error)) = result { if let Ok(PluginOutput::Deny(error)) = result {
return Ok(PluginOutput::Deny(error)); return Ok(PluginOutput::Deny(error));
@@ -1078,7 +1055,7 @@ impl QueryRouter {
/// Should we attempt to parse queries? /// Should we attempt to parse queries?
pub fn query_parser_enabled(&self) -> bool { pub fn query_parser_enabled(&self) -> bool {
let enabled = match self.query_parser_enabled { match self.query_parser_enabled {
None => { None => {
debug!( debug!(
"Using pool settings, query_parser_enabled: {}", "Using pool settings, query_parser_enabled: {}",
@@ -1094,9 +1071,7 @@ impl QueryRouter {
); );
value value
} }
}; }
enabled
} }
pub fn primary_reads_enabled(&self) -> bool { pub fn primary_reads_enabled(&self) -> bool {
@@ -1107,6 +1082,12 @@ impl QueryRouter {
} }
} }
impl Default for QueryRouter {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use super::*; use super::*;
@@ -1128,10 +1109,14 @@ mod test {
QueryRouter::setup(); QueryRouter::setup();
let mut qr = QueryRouter::new(); let mut qr = QueryRouter::new();
qr.pool_settings.query_parser_read_write_splitting = true; qr.pool_settings.query_parser_read_write_splitting = true;
assert!(qr.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'")) != None); assert!(qr
.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'"))
.is_some());
assert!(qr.query_parser_enabled()); assert!(qr.query_parser_enabled());
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None); assert!(qr
.try_execute_command(&simple_query("SET PRIMARY READS TO off"))
.is_some());
let queries = vec![ let queries = vec![
simple_query("SELECT * FROM items WHERE id = 5"), simple_query("SELECT * FROM items WHERE id = 5"),
@@ -1173,7 +1158,9 @@ mod test {
QueryRouter::setup(); QueryRouter::setup();
let mut qr = QueryRouter::new(); let mut qr = QueryRouter::new();
let query = simple_query("SELECT * FROM items WHERE id = 5"); let query = simple_query("SELECT * FROM items WHERE id = 5");
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO on")) != None); assert!(qr
.try_execute_command(&simple_query("SET PRIMARY READS TO on"))
.is_some());
assert!(qr.infer(&qr.parse(&query).unwrap()).is_ok()); assert!(qr.infer(&qr.parse(&query).unwrap()).is_ok());
assert_eq!(qr.role(), None); assert_eq!(qr.role(), None);
@@ -1186,7 +1173,9 @@ mod test {
qr.pool_settings.query_parser_read_write_splitting = true; qr.pool_settings.query_parser_read_write_splitting = true;
qr.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'")); qr.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'"));
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None); assert!(qr
.try_execute_command(&simple_query("SET PRIMARY READS TO off"))
.is_some());
let prepared_stmt = BytesMut::from( let prepared_stmt = BytesMut::from(
&b"WITH t AS (SELECT * FROM items WHERE name = $1) SELECT * FROM t WHERE id = $2\0"[..], &b"WITH t AS (SELECT * FROM items WHERE name = $1) SELECT * FROM t WHERE id = $2\0"[..],
@@ -1356,9 +1345,11 @@ mod test {
qr.pool_settings.query_parser_read_write_splitting = true; qr.pool_settings.query_parser_read_write_splitting = true;
let query = simple_query("SET SERVER ROLE TO 'auto'"); let query = simple_query("SET SERVER ROLE TO 'auto'");
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None); assert!(qr
.try_execute_command(&simple_query("SET PRIMARY READS TO off"))
.is_some());
assert!(qr.try_execute_command(&query) != None); assert!(qr.try_execute_command(&query).is_some());
assert!(qr.query_parser_enabled()); assert!(qr.query_parser_enabled());
assert_eq!(qr.role(), None); assert_eq!(qr.role(), None);
@@ -1372,7 +1363,7 @@ mod test {
assert!(qr.query_parser_enabled()); assert!(qr.query_parser_enabled());
let query = simple_query("SET SERVER ROLE TO 'default'"); let query = simple_query("SET SERVER ROLE TO 'default'");
assert!(qr.try_execute_command(&query) != None); assert!(qr.try_execute_command(&query).is_some());
assert!(!qr.query_parser_enabled()); assert!(!qr.query_parser_enabled());
} }
@@ -1420,11 +1411,11 @@ mod test {
assert!(!qr.primary_reads_enabled()); assert!(!qr.primary_reads_enabled());
let q1 = simple_query("SET SERVER ROLE TO 'primary'"); let q1 = simple_query("SET SERVER ROLE TO 'primary'");
assert!(qr.try_execute_command(&q1) != None); assert!(qr.try_execute_command(&q1).is_some());
assert_eq!(qr.active_role.unwrap(), Role::Primary); assert_eq!(qr.active_role.unwrap(), Role::Primary);
let q2 = simple_query("SET SERVER ROLE TO 'default'"); let q2 = simple_query("SET SERVER ROLE TO 'default'");
assert!(qr.try_execute_command(&q2) != None); assert!(qr.try_execute_command(&q2).is_some());
assert_eq!(qr.active_role.unwrap(), pool_settings.default_role); assert_eq!(qr.active_role.unwrap(), pool_settings.default_role);
} }
@@ -1485,29 +1476,29 @@ mod test {
}; };
let mut qr = QueryRouter::new(); let mut qr = QueryRouter::new();
qr.update_pool_settings(pool_settings.clone()); qr.update_pool_settings(pool_settings);
// Shard should start out unset // Shard should start out unset
assert_eq!(qr.active_shard, None); assert_eq!(qr.active_shard, None);
// Don't panic when short query eg. ; is sent // Don't panic when short query eg. ; is sent
let q0 = simple_query(";"); let q0 = simple_query(";");
assert!(qr.try_execute_command(&q0) == None); assert!(qr.try_execute_command(&q0).is_none());
assert_eq!(qr.active_shard, None); assert_eq!(qr.active_shard, None);
// Make sure setting it works // Make sure setting it works
let q1 = simple_query("/* shard_id: 1 */ select 1 from foo;"); let q1 = simple_query("/* shard_id: 1 */ select 1 from foo;");
assert!(qr.try_execute_command(&q1) == None); assert!(qr.try_execute_command(&q1).is_none());
assert_eq!(qr.active_shard, Some(1)); assert_eq!(qr.active_shard, Some(1));
// And make sure changing it works // And make sure changing it works
let q2 = simple_query("/* shard_id: 0 */ select 1 from foo;"); let q2 = simple_query("/* shard_id: 0 */ select 1 from foo;");
assert!(qr.try_execute_command(&q2) == None); assert!(qr.try_execute_command(&q2).is_none());
assert_eq!(qr.active_shard, Some(0)); assert_eq!(qr.active_shard, Some(0));
// Validate setting by shard with expected shard copied from sharding.rs tests // Validate setting by shard with expected shard copied from sharding.rs tests
let q2 = simple_query("/* sharding_key: 6 */ select 1 from foo;"); let q2 = simple_query("/* sharding_key: 6 */ select 1 from foo;");
assert!(qr.try_execute_command(&q2) == None); assert!(qr.try_execute_command(&q2).is_none());
assert_eq!(qr.active_shard, Some(2)); assert_eq!(qr.active_shard, Some(2));
} }
@@ -1863,10 +1854,11 @@ mod test {
}; };
QueryRouter::setup(); QueryRouter::setup();
let mut pool_settings = PoolSettings::default(); let pool_settings = PoolSettings {
pool_settings.query_parser_enabled = true; query_parser_enabled: true,
pool_settings.plugins = Some(plugins); plugins: Some(plugins),
..Default::default()
};
let mut qr = QueryRouter::new(); let mut qr = QueryRouter::new();
qr.update_pool_settings(pool_settings); qr.update_pool_settings(pool_settings);

View File

@@ -79,12 +79,12 @@ impl ScramSha256 {
let server_message = Message::parse(message)?; let server_message = Message::parse(message)?;
if !server_message.nonce.starts_with(&self.nonce) { if !server_message.nonce.starts_with(&self.nonce) {
return Err(Error::ProtocolSyncError(format!("SCRAM"))); return Err(Error::ProtocolSyncError("SCRAM".to_string()));
} }
let salt = match general_purpose::STANDARD.decode(&server_message.salt) { let salt = match general_purpose::STANDARD.decode(&server_message.salt) {
Ok(salt) => salt, Ok(salt) => salt,
Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))), Err(_) => return Err(Error::ProtocolSyncError("SCRAM".to_string())),
}; };
let salted_password = Self::hi( let salted_password = Self::hi(
@@ -166,9 +166,9 @@ impl ScramSha256 {
pub fn finish(&mut self, message: &BytesMut) -> Result<(), Error> { pub fn finish(&mut self, message: &BytesMut) -> Result<(), Error> {
let final_message = FinalMessage::parse(message)?; let final_message = FinalMessage::parse(message)?;
let verifier = match general_purpose::STANDARD.decode(&final_message.value) { let verifier = match general_purpose::STANDARD.decode(final_message.value) {
Ok(verifier) => verifier, Ok(verifier) => verifier,
Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))), Err(_) => return Err(Error::ProtocolSyncError("SCRAM".to_string())),
}; };
let mut hmac = match Hmac::<Sha256>::new_from_slice(&self.salted_password) { let mut hmac = match Hmac::<Sha256>::new_from_slice(&self.salted_password) {
@@ -230,14 +230,14 @@ impl Message {
.collect::<Vec<String>>(); .collect::<Vec<String>>();
if parts.len() != 3 { if parts.len() != 3 {
return Err(Error::ProtocolSyncError(format!("SCRAM"))); return Err(Error::ProtocolSyncError("SCRAM".to_string()));
} }
let nonce = str::replace(&parts[0], "r=", ""); let nonce = str::replace(&parts[0], "r=", "");
let salt = str::replace(&parts[1], "s=", ""); let salt = str::replace(&parts[1], "s=", "");
let iterations = match str::replace(&parts[2], "i=", "").parse::<u32>() { let iterations = match str::replace(&parts[2], "i=", "").parse::<u32>() {
Ok(iterations) => iterations, Ok(iterations) => iterations,
Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))), Err(_) => return Err(Error::ProtocolSyncError("SCRAM".to_string())),
}; };
Ok(Message { Ok(Message {
@@ -257,7 +257,7 @@ impl FinalMessage {
/// Parse the server final validation message. /// Parse the server final validation message.
pub fn parse(message: &BytesMut) -> Result<FinalMessage, Error> { pub fn parse(message: &BytesMut) -> Result<FinalMessage, Error> {
if !message.starts_with(b"v=") || message.len() < 4 { if !message.starts_with(b"v=") || message.len() < 4 {
return Err(Error::ProtocolSyncError(format!("SCRAM"))); return Err(Error::ProtocolSyncError("SCRAM".to_string()));
} }
Ok(FinalMessage { Ok(FinalMessage {

View File

@@ -197,12 +197,8 @@ impl ServerParameters {
key = "DateStyle".to_string(); key = "DateStyle".to_string();
}; };
if TRACKED_PARAMETERS.contains(&key) { if TRACKED_PARAMETERS.contains(&key) || startup {
self.parameters.insert(key, value); self.parameters.insert(key, value);
} else {
if startup {
self.parameters.insert(key, value);
}
} }
} }
@@ -332,6 +328,7 @@ pub struct Server {
impl Server { impl Server {
/// Pretend to be the Postgres client and connect to the server given host, port and credentials. /// Pretend to be the Postgres client and connect to the server given host, port and credentials.
/// Perform the authentication and return the server in a ready for query state. /// Perform the authentication and return the server in a ready for query state.
#[allow(clippy::too_many_arguments)]
pub async fn startup( pub async fn startup(
address: &Address, address: &Address,
user: &User, user: &User,
@@ -440,10 +437,7 @@ impl Server {
// Something else? // Something else?
m => { m => {
return Err(Error::SocketError(format!( return Err(Error::SocketError(format!("Unknown message: {}", { m })));
"Unknown message: {}",
m as char
)));
} }
} }
} else { } else {
@@ -461,6 +455,8 @@ impl Server {
None => &user.username, None => &user.username,
}; };
#[allow(clippy::match_as_ref)]
#[allow(clippy::manual_map)]
let password = match user.server_password { let password = match user.server_password {
Some(ref server_password) => Some(server_password), Some(ref server_password) => Some(server_password),
None => match user.password { None => match user.password {
@@ -473,14 +469,11 @@ impl Server {
let mut process_id: i32 = 0; let mut process_id: i32 = 0;
let mut secret_key: i32 = 0; let mut secret_key: i32 = 0;
let server_identifier = ServerIdentifier::new(username, &database); let server_identifier = ServerIdentifier::new(username, database);
// We'll be handling multiple packets, but they will all be structured the same. // We'll be handling multiple packets, but they will all be structured the same.
// We'll loop here until this exchange is complete. // We'll loop here until this exchange is complete.
let mut scram: Option<ScramSha256> = match password { let mut scram: Option<ScramSha256> = password.map(|password| ScramSha256::new(password));
Some(password) => Some(ScramSha256::new(password)),
None => None,
};
let mut server_parameters = ServerParameters::new(); let mut server_parameters = ServerParameters::new();
@@ -882,7 +875,7 @@ impl Server {
self.mirror_send(messages); self.mirror_send(messages);
self.stats().data_sent(messages.len()); self.stats().data_sent(messages.len());
match write_all_flush(&mut self.stream, &messages).await { match write_all_flush(&mut self.stream, messages).await {
Ok(_) => { Ok(_) => {
// Successfully sent to server // Successfully sent to server
self.last_activity = SystemTime::now(); self.last_activity = SystemTime::now();
@@ -1359,16 +1352,14 @@ impl Server {
} }
pub fn mirror_send(&mut self, bytes: &BytesMut) { pub fn mirror_send(&mut self, bytes: &BytesMut) {
match self.mirror_manager.as_mut() { if let Some(manager) = self.mirror_manager.as_mut() {
Some(manager) => manager.send(bytes), manager.send(bytes)
None => (),
} }
} }
pub fn mirror_disconnect(&mut self) { pub fn mirror_disconnect(&mut self) {
match self.mirror_manager.as_mut() { if let Some(manager) = self.mirror_manager.as_mut() {
Some(manager) => manager.disconnect(), manager.disconnect()
None => (),
} }
} }
@@ -1397,7 +1388,7 @@ impl Server {
server.send(&simple_query(query)).await?; server.send(&simple_query(query)).await?;
let mut message = server.recv(None).await?; let mut message = server.recv(None).await?;
Ok(parse_query_message(&mut message).await?) parse_query_message(&mut message).await
} }
} }

View File

@@ -64,7 +64,7 @@ impl Sharder {
fn sha1(&self, key: i64) -> usize { fn sha1(&self, key: i64) -> usize {
let mut hasher = Sha1::new(); let mut hasher = Sha1::new();
hasher.update(&key.to_string().as_bytes()); hasher.update(key.to_string().as_bytes());
let result = hasher.finalize(); let result = hasher.finalize();
@@ -202,10 +202,10 @@ mod test {
#[test] #[test]
fn test_sha1_hash() { fn test_sha1_hash() {
let sharder = Sharder::new(12, ShardingFunction::Sha1); let sharder = Sharder::new(12, ShardingFunction::Sha1);
let ids = vec![ let ids = [
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
]; ];
let shards = vec![ let shards = [
4, 7, 8, 3, 6, 0, 0, 10, 3, 11, 1, 7, 4, 4, 11, 2, 5, 0, 8, 3, 4, 7, 8, 3, 6, 0, 0, 10, 3, 11, 1, 7, 4, 4, 11, 2, 5, 0, 8, 3,
]; ];

View File

@@ -86,11 +86,11 @@ impl PoolStats {
} }
} }
return map; map
} }
pub fn generate_header() -> Vec<(&'static str, DataType)> { pub fn generate_header() -> Vec<(&'static str, DataType)> {
return vec![ vec![
("database", DataType::Text), ("database", DataType::Text),
("user", DataType::Text), ("user", DataType::Text),
("pool_mode", DataType::Text), ("pool_mode", DataType::Text),
@@ -105,11 +105,11 @@ impl PoolStats {
("sv_login", DataType::Numeric), ("sv_login", DataType::Numeric),
("maxwait", DataType::Numeric), ("maxwait", DataType::Numeric),
("maxwait_us", DataType::Numeric), ("maxwait_us", DataType::Numeric),
]; ]
} }
pub fn generate_row(&self) -> Vec<String> { pub fn generate_row(&self) -> Vec<String> {
return vec![ vec![
self.identifier.db.clone(), self.identifier.db.clone(),
self.identifier.user.clone(), self.identifier.user.clone(),
self.mode.to_string(), self.mode.to_string(),
@@ -124,7 +124,7 @@ impl PoolStats {
self.sv_login.to_string(), self.sv_login.to_string(),
(self.maxwait / 1_000_000).to_string(), (self.maxwait / 1_000_000).to_string(),
(self.maxwait % 1_000_000).to_string(), (self.maxwait % 1_000_000).to_string(),
]; ]
} }
} }