chore: make clippy lint happy (#225)

* chore: make clippy happy

* chore: cargo fmt

* chore: cargo fmt
This commit is contained in:
Cluas
2022-11-10 02:04:31 +08:00
committed by GitHub
parent 4bd5717ab1
commit dfa26ec6f8
13 changed files with 167 additions and 207 deletions

View File

@@ -22,7 +22,7 @@ pub fn generate_server_info_for_admin() -> BytesMut {
server_info.put(server_parameter_message("server_version", VERSION)); server_info.put(server_parameter_message("server_version", VERSION));
server_info.put(server_parameter_message("DateStyle", "ISO, MDY")); server_info.put(server_parameter_message("DateStyle", "ISO, MDY"));
return server_info; server_info
} }
/// Handle admin client. /// Handle admin client.
@@ -179,7 +179,7 @@ where
let mut res = BytesMut::new(); let mut res = BytesMut::new();
res.put(row_description(&vec![("version", DataType::Text)])); res.put(row_description(&vec![("version", DataType::Text)]));
res.put(data_row(&vec![format!("PgCat {}", VERSION).to_string()])); res.put(data_row(&vec![format!("PgCat {}", VERSION)]));
res.put(command_complete("SHOW")); res.put(command_complete("SHOW"));
res.put_u8(b'Z'); res.put_u8(b'Z');

View File

@@ -377,7 +377,7 @@ where
let admin = ["pgcat", "pgbouncer"] let admin = ["pgcat", "pgbouncer"]
.iter() .iter()
.filter(|db| *db == &pool_name) .filter(|db| *db == pool_name)
.count() .count()
== 1; == 1;
@@ -389,7 +389,7 @@ where
); );
error_response_terminal( error_response_terminal(
&mut write, &mut write,
&format!("terminating connection due to administrator command"), "terminating connection due to administrator command",
) )
.await?; .await?;
return Err(Error::ShuttingDown); return Err(Error::ShuttingDown);
@@ -446,7 +446,7 @@ where
} }
// Authenticate normal user. // Authenticate normal user.
else { else {
let pool = match get_pool(&pool_name, &username) { let pool = match get_pool(pool_name, username) {
Some(pool) => pool, Some(pool) => pool,
None => { None => {
error_response( error_response(
@@ -464,7 +464,7 @@ where
}; };
// Compare server and client hashes. // Compare server and client hashes.
let password_hash = md5_hash_password(&username, &pool.settings.user.password, &salt); let password_hash = md5_hash_password(username, &pool.settings.user.password, &salt);
if password_hash != password_response { if password_hash != password_response {
warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name); warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name);
@@ -487,9 +487,9 @@ where
trace!("Startup OK"); trace!("Startup OK");
return Ok(Client { Ok(Client {
read: BufReader::new(read), read: BufReader::new(read),
write: write, write,
addr, addr,
buffer: BytesMut::with_capacity(8196), buffer: BytesMut::with_capacity(8196),
cancel_mode: false, cancel_mode: false,
@@ -498,8 +498,8 @@ where
secret_key, secret_key,
client_server_map, client_server_map,
parameters: parameters.clone(), parameters: parameters.clone(),
stats: stats, stats,
admin: admin, admin,
last_address_id: None, last_address_id: None,
last_server_id: None, last_server_id: None,
pool_name: pool_name.clone(), pool_name: pool_name.clone(),
@@ -507,7 +507,7 @@ where
application_name: application_name.to_string(), application_name: application_name.to_string(),
shutdown, shutdown,
connected_to_server: false, connected_to_server: false,
}); })
} }
/// Handle cancel request. /// Handle cancel request.
@@ -521,9 +521,9 @@ where
) -> Result<Client<S, T>, Error> { ) -> Result<Client<S, T>, Error> {
let process_id = bytes.get_i32(); let process_id = bytes.get_i32();
let secret_key = bytes.get_i32(); let secret_key = bytes.get_i32();
return Ok(Client { Ok(Client {
read: BufReader::new(read), read: BufReader::new(read),
write: write, write,
addr, addr,
buffer: BytesMut::with_capacity(8196), buffer: BytesMut::with_capacity(8196),
cancel_mode: true, cancel_mode: true,
@@ -541,7 +541,7 @@ where
application_name: String::from("undefined"), application_name: String::from("undefined"),
shutdown, shutdown,
connected_to_server: false, connected_to_server: false,
}); })
} }
/// Handle a connected and authenticated client. /// Handle a connected and authenticated client.
@@ -557,12 +557,9 @@ where
// Drop the mutex as soon as possible. // Drop the mutex as soon as possible.
// We found the server the client is using for its query // We found the server the client is using for its query
// that it wants to cancel. // that it wants to cancel.
Some((process_id, secret_key, address, port)) => ( Some((process_id, secret_key, address, port)) => {
process_id.clone(), (*process_id, *secret_key, address.clone(), *port)
secret_key.clone(), }
address.clone(),
*port,
),
// The client doesn't know / got the wrong server, // The client doesn't know / got the wrong server,
// we're closing the connection for security reasons. // we're closing the connection for security reasons.
@@ -573,7 +570,7 @@ where
// Opens a new separate connection to the server, sends the backend_id // Opens a new separate connection to the server, sends the backend_id
// and secret_key and then closes it for security reasons. No other interactions // and secret_key and then closes it for security reasons. No other interactions
// take place. // take place.
return Ok(Server::cancel(&address, port, process_id, secret_key).await?); return Server::cancel(&address, port, process_id, secret_key).await;
} }
// The query router determines where the query is going to go, // The query router determines where the query is going to go,
@@ -606,7 +603,7 @@ where
if !self.admin { if !self.admin {
error_response_terminal( error_response_terminal(
&mut self.write, &mut self.write,
&format!("terminating connection due to administrator command") "terminating connection due to administrator command"
).await?; ).await?;
return Ok(()) return Ok(())
} }
@@ -998,14 +995,14 @@ where
) -> Result<(), Error> { ) -> Result<(), Error> {
debug!("Sending {} to server", code); debug!("Sending {} to server", code);
self.send_server_message(server, message, &address, &pool) self.send_server_message(server, message, address, pool)
.await?; .await?;
let query_start = Instant::now(); let query_start = Instant::now();
// Read all data the server has to offer, which can be multiple messages // Read all data the server has to offer, which can be multiple messages
// buffered in 8196 bytes chunks. // buffered in 8196 bytes chunks.
loop { loop {
let response = self.receive_server_message(server, &address, &pool).await?; let response = self.receive_server_message(server, address, pool).await?;
match write_all_half(&mut self.write, response).await { match write_all_half(&mut self.write, response).await {
Ok(_) => (), Ok(_) => (),

View File

@@ -9,7 +9,6 @@ use std::path::Path;
use std::sync::Arc; use std::sync::Arc;
use tokio::fs::File; use tokio::fs::File;
use tokio::io::AsyncReadExt; use tokio::io::AsyncReadExt;
use toml;
use crate::errors::Error; use crate::errors::Error;
use crate::pool::{ClientServerMap, ConnectionPool}; use crate::pool::{ClientServerMap, ConnectionPool};
@@ -353,7 +352,7 @@ impl Shard {
let mut dup_check = HashSet::new(); let mut dup_check = HashSet::new();
let mut primary_count = 0; let mut primary_count = 0;
if self.servers.len() == 0 { if self.servers.is_empty() {
error!("Shard {} has no servers configured", self.database); error!("Shard {} has no servers configured", self.database);
return Err(Error::BadConfig); return Err(Error::BadConfig);
} }
@@ -362,10 +361,9 @@ impl Shard {
dup_check.insert(server); dup_check.insert(server);
// Check that we define only zero or one primary. // Check that we define only zero or one primary.
match server.role { if server.role == Role::Primary {
Role::Primary => primary_count += 1, primary_count += 1
_ => (), }
};
} }
if primary_count > 1 { if primary_count > 1 {
@@ -605,22 +603,17 @@ impl Config {
// Validate TLS! // Validate TLS!
match self.general.tls_certificate.clone() { match self.general.tls_certificate.clone() {
Some(tls_certificate) => { Some(tls_certificate) => {
match load_certs(&Path::new(&tls_certificate)) { match load_certs(Path::new(&tls_certificate)) {
Ok(_) => { Ok(_) => {
// Cert is okay, but what about the private key? // Cert is okay, but what about the private key?
match self.general.tls_private_key.clone() { match self.general.tls_private_key.clone() {
Some(tls_private_key) => { Some(tls_private_key) => match load_keys(Path::new(&tls_private_key)) {
match load_keys(&Path::new(&tls_private_key)) { Ok(_) => (),
Ok(_) => (), Err(err) => {
Err(err) => { error!("tls_private_key is incorrectly configured: {:?}", err);
error!( return Err(Error::BadConfig);
"tls_private_key is incorrectly configured: {:?}",
err
);
return Err(Error::BadConfig);
}
} }
} },
None => { None => {
error!("tls_certificate is set, but the tls_private_key is not"); error!("tls_certificate is set, but the tls_private_key is not");
@@ -638,7 +631,7 @@ impl Config {
None => (), None => (),
}; };
for (_, pool) in &mut self.pools { for pool in self.pools.values_mut() {
pool.validate()?; pool.validate()?;
} }

View File

@@ -75,7 +75,6 @@ mod stats;
mod tls; mod tls;
use crate::config::{get_config, reload_config, VERSION}; use crate::config::{get_config, reload_config, VERSION};
use crate::errors::Error;
use crate::pool::{ClientServerMap, ConnectionPool}; use crate::pool::{ClientServerMap, ConnectionPool};
use crate::prometheus::start_metric_server; use crate::prometheus::start_metric_server;
use crate::stats::{Collector, Reporter, REPORTER}; use crate::stats::{Collector, Reporter, REPORTER};
@@ -171,13 +170,10 @@ async fn main() {
if config.general.autoreload { if config.general.autoreload {
info!("Automatically reloading config"); info!("Automatically reloading config");
match reload_config(autoreload_client_server_map.clone()).await { if let Ok(changed) = reload_config(autoreload_client_server_map.clone()).await {
Ok(changed) => { if changed {
if changed { get_config().show()
get_config().show()
}
} }
Err(_) => (),
}; };
} }
} }
@@ -202,10 +198,7 @@ async fn main() {
_ = sighup_signal.recv() => { _ = sighup_signal.recv() => {
info!("Reloading config"); info!("Reloading config");
match reload_config(client_server_map.clone()).await { _ = reload_config(client_server_map.clone()).await;
Ok(_) => (),
Err(_) => (),
};
get_config().show(); get_config().show();
}, },
@@ -278,14 +271,6 @@ async fn main() {
} }
Err(err) => { Err(err) => {
match err {
// Don't count the clients we rejected.
Error::ShuttingDown => (),
_ => {
// drain_tx.send(-1).await.unwrap();
}
}
warn!("Client disconnected with error {:?}", err); warn!("Client disconnected with error {:?}", err);
} }
}; };

View File

@@ -38,7 +38,7 @@ where
auth_ok.put_i32(8); auth_ok.put_i32(8);
auth_ok.put_i32(0); auth_ok.put_i32(0);
Ok(write_all(stream, auth_ok).await?) write_all(stream, auth_ok).await
} }
/// Generate md5 password challenge. /// Generate md5 password challenge.
@@ -79,7 +79,7 @@ where
key_data.put_i32(backend_id); key_data.put_i32(backend_id);
key_data.put_i32(secret_key); key_data.put_i32(secret_key);
Ok(write_all(stream, key_data).await?) write_all(stream, key_data).await
} }
/// Construct a `Q`: Query message. /// Construct a `Q`: Query message.
@@ -88,7 +88,7 @@ pub fn simple_query(query: &str) -> BytesMut {
let query = format!("{}\0", query); let query = format!("{}\0", query);
res.put_i32(query.len() as i32 + 4); res.put_i32(query.len() as i32 + 4);
res.put_slice(&query.as_bytes()); res.put_slice(query.as_bytes());
res res
} }
@@ -106,7 +106,7 @@ where
bytes.put_i32(5); bytes.put_i32(5);
bytes.put_u8(b'I'); // Idle bytes.put_u8(b'I'); // Idle
Ok(write_all(stream, bytes).await?) write_all(stream, bytes).await
} }
/// Send the startup packet the server. We're pretending we're a Pg client. /// Send the startup packet the server. We're pretending we're a Pg client.
@@ -118,12 +118,12 @@ pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Resu
// User // User
bytes.put(&b"user\0"[..]); bytes.put(&b"user\0"[..]);
bytes.put_slice(&user.as_bytes()); bytes.put_slice(user.as_bytes());
bytes.put_u8(0); bytes.put_u8(0);
// Database // Database
bytes.put(&b"database\0"[..]); bytes.put(&b"database\0"[..]);
bytes.put_slice(&database.as_bytes()); bytes.put_slice(database.as_bytes());
bytes.put_u8(0); bytes.put_u8(0);
bytes.put_u8(0); // Null terminator bytes.put_u8(0); // Null terminator
@@ -136,7 +136,7 @@ pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Resu
match stream.write_all(&startup).await { match stream.write_all(&startup).await {
Ok(_) => Ok(()), Ok(_) => Ok(()),
Err(_) => return Err(Error::SocketError), Err(_) => Err(Error::SocketError),
} }
} }
@@ -155,7 +155,7 @@ pub fn parse_params(mut bytes: BytesMut) -> Result<HashMap<String, String>, Erro
c = bytes.get_u8(); c = bytes.get_u8();
} }
if tmp.len() > 0 { if !tmp.is_empty() {
buf.push(tmp.clone()); buf.push(tmp.clone());
tmp.clear(); tmp.clear();
} }
@@ -234,7 +234,7 @@ where
message.put_i32(password.len() as i32 + 4); message.put_i32(password.len() as i32 + 4);
message.put_slice(&password[..]); message.put_slice(&password[..]);
Ok(write_all(stream, message).await?) write_all(stream, message).await
} }
/// Implements a response to our custom `SET SHARDING KEY` /// Implements a response to our custom `SET SHARDING KEY`
@@ -292,7 +292,7 @@ where
// The short error message. // The short error message.
error.put_u8(b'M'); error.put_u8(b'M');
error.put_slice(&format!("{}\0", message).as_bytes()); error.put_slice(format!("{}\0", message).as_bytes());
// No more fields follow. // No more fields follow.
error.put_u8(0); error.put_u8(0);
@@ -304,7 +304,7 @@ where
res.put_i32(error.len() as i32 + 4); res.put_i32(error.len() as i32 + 4);
res.put(error); res.put(error);
Ok(write_all_half(stream, res).await?) write_all_half(stream, res).await
} }
pub async fn wrong_password<S>(stream: &mut S, user: &str) -> Result<(), Error> pub async fn wrong_password<S>(stream: &mut S, user: &str) -> Result<(), Error>
@@ -327,7 +327,7 @@ where
// The short error message. // The short error message.
error.put_u8(b'M'); error.put_u8(b'M');
error.put_slice(&format!("password authentication failed for user \"{}\"\0", user).as_bytes()); error.put_slice(format!("password authentication failed for user \"{}\"\0", user).as_bytes());
// No more fields follow. // No more fields follow.
error.put_u8(0); error.put_u8(0);
@@ -379,7 +379,7 @@ pub fn row_description(columns: &Vec<(&str, DataType)>) -> BytesMut {
for (name, data_type) in columns { for (name, data_type) in columns {
// Column name // Column name
row_desc.put_slice(&format!("{}\0", name).as_bytes()); row_desc.put_slice(format!("{}\0", name).as_bytes());
// Doesn't belong to any table // Doesn't belong to any table
row_desc.put_i32(0); row_desc.put_i32(0);
@@ -423,7 +423,7 @@ pub fn data_row(row: &Vec<String>) -> BytesMut {
for column in row { for column in row {
let column = column.as_bytes(); let column = column.as_bytes();
data_row.put_i32(column.len() as i32); data_row.put_i32(column.len() as i32);
data_row.put_slice(&column); data_row.put_slice(column);
} }
res.put_u8(b'D'); res.put_u8(b'D');
@@ -450,7 +450,7 @@ where
{ {
match stream.write_all(&buf).await { match stream.write_all(&buf).await {
Ok(_) => Ok(()), Ok(_) => Ok(()),
Err(_) => return Err(Error::SocketError), Err(_) => Err(Error::SocketError),
} }
} }
@@ -461,7 +461,7 @@ where
{ {
match stream.write_all(&buf).await { match stream.write_all(&buf).await {
Ok(_) => Ok(()), Ok(_) => Ok(()),
Err(_) => return Err(Error::SocketError), Err(_) => Err(Error::SocketError),
} }
} }
@@ -510,5 +510,5 @@ pub fn server_parameter_message(key: &str, value: &str) -> BytesMut {
server_info.put_slice(value.as_bytes()); server_info.put_slice(value.as_bytes());
server_info.put_bytes(0, 1); server_info.put_bytes(0, 1);
return server_info; server_info
} }

View File

@@ -140,18 +140,18 @@ impl ConnectionPool {
let changed = pools_hash.insert(pool_config.clone()); let changed = pools_hash.insert(pool_config.clone());
// There is one pool per database/user pair. // There is one pool per database/user pair.
for (_, user) in &pool_config.users { for user in pool_config.users.values() {
// If the pool hasn't changed, get existing reference and insert it into the new_pools. // If the pool hasn't changed, get existing reference and insert it into the new_pools.
// We replace all pools at the end, but if the reference is kept, the pool won't get re-created (bb8). // We replace all pools at the end, but if the reference is kept, the pool won't get re-created (bb8).
if !changed { if !changed {
match get_pool(&pool_name, &user.username) { match get_pool(pool_name, &user.username) {
Some(pool) => { Some(pool) => {
info!( info!(
"[pool: {}][user: {}] has not changed", "[pool: {}][user: {}] has not changed",
pool_name, user.username pool_name, user.username
); );
new_pools.insert( new_pools.insert(
PoolIdentifier::new(&pool_name, &user.username), PoolIdentifier::new(pool_name, &user.username),
pool.clone(), pool.clone(),
); );
continue; continue;
@@ -172,7 +172,6 @@ impl ConnectionPool {
.shards .shards
.clone() .clone()
.into_keys() .into_keys()
.map(|x| x.to_string())
.collect::<Vec<String>>(); .collect::<Vec<String>>();
// Sort by shard number to ensure consistency. // Sort by shard number to ensure consistency.
@@ -182,10 +181,9 @@ impl ConnectionPool {
let shard = &pool_config.shards[shard_idx]; let shard = &pool_config.shards[shard_idx];
let mut pools = Vec::new(); let mut pools = Vec::new();
let mut servers = Vec::new(); let mut servers = Vec::new();
let mut address_index = 0;
let mut replica_number = 0; let mut replica_number = 0;
for server in shard.servers.iter() { for (address_index, server) in shard.servers.iter().enumerate() {
let address = Address { let address = Address {
id: address_id, id: address_id,
database: shard.database.clone(), database: shard.database.clone(),
@@ -200,7 +198,6 @@ impl ConnectionPool {
}; };
address_id += 1; address_id += 1;
address_index += 1;
if server.role == Role::Replica { if server.role == Role::Replica {
replica_number += 1; replica_number += 1;
@@ -240,7 +237,7 @@ impl ConnectionPool {
let mut pool = ConnectionPool { let mut pool = ConnectionPool {
databases: shards, databases: shards,
addresses: addresses, addresses,
banlist: Arc::new(RwLock::new(banlist)), banlist: Arc::new(RwLock::new(banlist)),
stats: get_reporter(), stats: get_reporter(),
server_info: BytesMut::new(), server_info: BytesMut::new(),
@@ -255,7 +252,7 @@ impl ConnectionPool {
"primary" => Some(Role::Primary), "primary" => Some(Role::Primary),
_ => unreachable!(), _ => unreachable!(),
}, },
query_parser_enabled: pool_config.query_parser_enabled.clone(), query_parser_enabled: pool_config.query_parser_enabled,
primary_reads_enabled: pool_config.primary_reads_enabled, primary_reads_enabled: pool_config.primary_reads_enabled,
sharding_function: pool_config.sharding_function, sharding_function: pool_config.sharding_function,
automatic_sharding_key: pool_config.automatic_sharding_key.clone(), automatic_sharding_key: pool_config.automatic_sharding_key.clone(),
@@ -273,7 +270,7 @@ impl ConnectionPool {
}; };
// There is one pool per database/user pair. // There is one pool per database/user pair.
new_pools.insert(PoolIdentifier::new(&pool_name, &user.username), pool); new_pools.insert(PoolIdentifier::new(pool_name, &user.username), pool);
} }
} }
@@ -304,7 +301,7 @@ impl ConnectionPool {
let server = &*proxy; let server = &*proxy;
let server_info = server.server_info(); let server_info = server.server_info();
if server_infos.len() > 0 { if !server_infos.is_empty() {
// Compare against the last server checked. // Compare against the last server checked.
if server_info != server_infos[server_infos.len() - 1] { if server_info != server_infos[server_infos.len() - 1] {
warn!( warn!(
@@ -320,7 +317,7 @@ impl ConnectionPool {
// TODO: compare server information to make sure // TODO: compare server information to make sure
// all shards are running identical configurations. // all shards are running identical configurations.
if server_infos.len() == 0 { if server_infos.is_empty() {
return Err(Error::AllServersDown); return Err(Error::AllServersDown);
} }
@@ -356,7 +353,7 @@ impl ConnectionPool {
None => break, None => break,
}; };
if self.is_banned(&address, role) { if self.is_banned(address, role) {
debug!("Address {:?} is banned", address); debug!("Address {:?} is banned", address);
continue; continue;
} }
@@ -373,7 +370,7 @@ impl ConnectionPool {
Ok(conn) => conn, Ok(conn) => conn,
Err(err) => { Err(err) => {
error!("Banning instance {:?}, error: {:?}", address, err); error!("Banning instance {:?}, error: {:?}", address, err);
self.ban(&address, process_id); self.ban(address, process_id);
self.stats.client_checkout_error(process_id, address.id); self.stats.client_checkout_error(process_id, address.id);
continue; continue;
} }
@@ -428,7 +425,7 @@ impl ConnectionPool {
// Don't leave a bad connection in the pool. // Don't leave a bad connection in the pool.
server.mark_bad(); server.mark_bad();
self.ban(&address, process_id); self.ban(address, process_id);
continue; continue;
} }
}, },
@@ -442,7 +439,7 @@ impl ConnectionPool {
// Don't leave a bad connection in the pool. // Don't leave a bad connection in the pool.
server.mark_bad(); server.mark_bad();
self.ban(&address, process_id); self.ban(address, process_id);
continue; continue;
} }
} }
@@ -575,11 +572,11 @@ impl ServerPool {
stats: Reporter, stats: Reporter,
) -> ServerPool { ) -> ServerPool {
ServerPool { ServerPool {
address: address, address,
user: user, user,
database: database.to_string(), database: database.to_string(),
client_server_map: client_server_map, client_server_map,
stats: stats, stats,
} }
} }
} }
@@ -638,15 +635,14 @@ impl ManageConnection for ServerPool {
/// Get the connection pool /// Get the connection pool
pub fn get_pool(db: &str, user: &str) -> Option<ConnectionPool> { pub fn get_pool(db: &str, user: &str) -> Option<ConnectionPool> {
match (*(*POOLS.load())).get(&PoolIdentifier::new(db, user)) { (*(*POOLS.load()))
Some(pool) => Some(pool.clone()), .get(&PoolIdentifier::new(db, user))
None => None, .cloned()
}
} }
/// Get a pointer to all configured pools. /// Get a pointer to all configured pools.
pub fn get_all_pools() -> HashMap<PoolIdentifier, ConnectionPool> { pub fn get_all_pools() -> HashMap<PoolIdentifier, ConnectionPool> {
return (*(*POOLS.load())).clone(); (*(*POOLS.load())).clone()
} }
/// How many total servers we have in the config. /// How many total servers we have in the config.

View File

@@ -203,7 +203,7 @@ async fn prometheus_stats(request: Request<Body>) -> Result<Response<Body>, hype
pub async fn start_metric_server(http_addr: SocketAddr) { pub async fn start_metric_server(http_addr: SocketAddr) {
let http_service_factory = let http_service_factory =
make_service_fn(|_conn| async { Ok::<_, hyper::Error>(service_fn(prometheus_stats)) }); make_service_fn(|_conn| async { Ok::<_, hyper::Error>(service_fn(prometheus_stats)) });
let server = Server::bind(&http_addr.into()).serve(http_service_factory); let server = Server::bind(&http_addr).serve(http_service_factory);
info!( info!(
"Exposing prometheus metrics on http://{}/metrics.", "Exposing prometheus metrics on http://{}/metrics.",
http_addr http_addr

View File

@@ -86,10 +86,7 @@ impl QueryRouter {
Err(_) => return false, Err(_) => return false,
}; };
match CUSTOM_SQL_REGEX_SET.set(set) { CUSTOM_SQL_REGEX_SET.set(set).is_ok()
Ok(_) => true,
Err(_) => false,
}
} }
/// Create a new instance of the query router. /// Create a new instance of the query router.
@@ -276,7 +273,6 @@ impl QueryRouter {
// Parse (prepared statement) // Parse (prepared statement)
'P' => { 'P' => {
let mut start = 0; let mut start = 0;
let mut end;
// Skip the name of the prepared statement. // Skip the name of the prepared statement.
while buf[start] != 0 && start < buf.len() { while buf[start] != 0 && start < buf.len() {
@@ -285,7 +281,7 @@ impl QueryRouter {
start += 1; // Skip terminating null start += 1; // Skip terminating null
// Find the end of the prepared stmt (\0) // Find the end of the prepared stmt (\0)
end = start; let mut end = start;
while buf[end] != 0 && end < buf.len() { while buf[end] != 0 && end < buf.len() {
end += 1; end += 1;
} }
@@ -294,7 +290,7 @@ impl QueryRouter {
debug!("Prepared statement: '{}'", query); debug!("Prepared statement: '{}'", query);
query.replace("$", "") // Remove placeholders turning them into "values" query.replace('$', "") // Remove placeholders turning them into "values"
} }
_ => return false, _ => return false,
@@ -312,7 +308,7 @@ impl QueryRouter {
debug!("AST: {:?}", ast); debug!("AST: {:?}", ast);
if ast.len() == 0 { if ast.is_empty() {
// That's weird, no idea, let's go to primary // That's weird, no idea, let's go to primary
self.active_role = Some(Role::Primary); self.active_role = Some(Role::Primary);
return false; return false;
@@ -371,50 +367,46 @@ impl QueryRouter {
let mut result = Vec::new(); let mut result = Vec::new();
let mut found = false; let mut found = false;
match expr { // This parses `sharding_key = 5`. But it's technically
// This parses `sharding_key = 5`. But it's technically // legal to write `5 = sharding_key`. I don't judge the people
// legal to write `5 = sharding_key`. I don't judge the people // who do that, but I think ORMs will still use the first variant,
// who do that, but I think ORMs will still use the first variant, // so we can leave the second as a TODO.
// so we can leave the second as a TODO. if let Expr::BinaryOp { left, op, right } = expr {
Expr::BinaryOp { left, op, right } => { match &**left {
match &**left { Expr::BinaryOp { .. } => result.extend(self.selection_parser(left)),
Expr::BinaryOp { .. } => result.extend(self.selection_parser(&left)), Expr::Identifier(ident) => {
Expr::Identifier(ident) => { found =
found = ident.value ident.value == *self.pool_settings.automatic_sharding_key.as_ref().unwrap();
== *self.pool_settings.automatic_sharding_key.as_ref().unwrap(); }
} _ => (),
_ => (), };
};
match op { match op {
BinaryOperator::Eq => (), BinaryOperator::Eq => (),
BinaryOperator::Or => (), BinaryOperator::Or => (),
BinaryOperator::And => (), BinaryOperator::And => (),
_ => { _ => {
// TODO: support other operators than equality. // TODO: support other operators than equality.
debug!("Unsupported operation: {:?}", op); debug!("Unsupported operation: {:?}", op);
return Vec::new(); return Vec::new();
} }
}; };
match &**right { match &**right {
Expr::BinaryOp { .. } => result.extend(self.selection_parser(&right)), Expr::BinaryOp { .. } => result.extend(self.selection_parser(right)),
Expr::Value(Value::Number(value, ..)) => { Expr::Value(Value::Number(value, ..)) => {
if found { if found {
match value.parse::<i64>() { match value.parse::<i64>() {
Ok(value) => result.push(value), Ok(value) => result.push(value),
Err(_) => { Err(_) => {
debug!("Sharding key was not an integer: {}", value); debug!("Sharding key was not an integer: {}", value);
} }
}; };
}
} }
_ => (), }
}; _ => (),
} };
}
_ => (),
};
debug!("Sharding keys found: {:?}", result); debug!("Sharding keys found: {:?}", result);
@@ -438,7 +430,7 @@ impl QueryRouter {
SetExpr::Select(select) => { SetExpr::Select(select) => {
match &select.selection { match &select.selection {
Some(selection) => { Some(selection) => {
let sharding_keys = self.selection_parser(&selection); let sharding_keys = self.selection_parser(selection);
// TODO: Add support for prepared statements here. // TODO: Add support for prepared statements here.
// This should just give us the position of the value in the `B` message. // This should just give us the position of the value in the `B` message.
@@ -484,10 +476,7 @@ impl QueryRouter {
/// Get desired shard we should be talking to. /// Get desired shard we should be talking to.
pub fn shard(&self) -> usize { pub fn shard(&self) -> usize {
match self.active_shard { self.active_shard.unwrap_or(0)
Some(shard) => shard,
None => 0,
}
} }
pub fn set_shard(&mut self, shard: usize) { pub fn set_shard(&mut self, shard: usize) {
@@ -531,7 +520,7 @@ mod test {
QueryRouter::setup(); QueryRouter::setup();
let mut qr = QueryRouter::new(); let mut qr = QueryRouter::new();
assert!(qr.try_execute_command(simple_query("SET SERVER ROLE TO 'auto'")) != None); assert!(qr.try_execute_command(simple_query("SET SERVER ROLE TO 'auto'")) != None);
assert_eq!(qr.query_parser_enabled(), true); assert!(qr.query_parser_enabled());
assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None); assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None);
@@ -656,9 +645,9 @@ mod test {
for (i, test) in tests.iter().enumerate() { for (i, test) in tests.iter().enumerate() {
if !list[matches[i]].is_match(test) { if !list[matches[i]].is_match(test) {
println!("{} does not match {}", test, list[matches[i]]); println!("{} does not match {}", test, list[matches[i]]);
assert!(false); panic!();
} }
assert_eq!(set.matches(test).into_iter().collect::<Vec<_>>().len(), 1); assert_eq!(set.matches(test).into_iter().count(), 1);
} }
let bad = [ let bad = [
@@ -667,7 +656,7 @@ mod test {
]; ];
for query in &bad { for query in &bad {
assert_eq!(set.matches(query).into_iter().collect::<Vec<_>>().len(), 0); assert_eq!(set.matches(query).into_iter().count(), 0);
} }
} }
@@ -760,11 +749,11 @@ mod test {
assert_eq!(qr.role(), None); assert_eq!(qr.role(), None);
let query = simple_query("INSERT INTO test_table VALUES (1)"); let query = simple_query("INSERT INTO test_table VALUES (1)");
assert_eq!(qr.infer(query), true); assert!(qr.infer(query));
assert_eq!(qr.role(), Some(Role::Primary)); assert_eq!(qr.role(), Some(Role::Primary));
let query = simple_query("SELECT * FROM test_table"); let query = simple_query("SELECT * FROM test_table");
assert_eq!(qr.infer(query), true); assert!(qr.infer(query));
assert_eq!(qr.role(), Some(Role::Replica)); assert_eq!(qr.role(), Some(Role::Replica));
assert!(qr.query_parser_enabled()); assert!(qr.query_parser_enabled());
@@ -798,8 +787,8 @@ mod test {
assert_eq!(qr.active_role, None); assert_eq!(qr.active_role, None);
assert_eq!(qr.active_shard, None); assert_eq!(qr.active_shard, None);
assert_eq!(qr.query_parser_enabled(), true); assert!(qr.query_parser_enabled());
assert_eq!(qr.primary_reads_enabled(), false); assert!(!qr.primary_reads_enabled());
let q1 = simple_query("SET SERVER ROLE TO 'primary'"); let q1 = simple_query("SET SERVER ROLE TO 'primary'");
assert!(qr.try_execute_command(q1) != None); assert!(qr.try_execute_command(q1) != None);
@@ -807,7 +796,7 @@ mod test {
let q2 = simple_query("SET SERVER ROLE TO 'default'"); let q2 = simple_query("SET SERVER ROLE TO 'default'");
assert!(qr.try_execute_command(q2) != None); assert!(qr.try_execute_command(q2) != None);
assert_eq!(qr.active_role.unwrap(), pool_settings.clone().default_role); assert_eq!(qr.active_role.unwrap(), pool_settings.default_role);
// Here we go :) // Here we go :)
let q3 = simple_query("SELECT * FROM test WHERE id = 5 AND values IN (1, 2, 3)"); let q3 = simple_query("SELECT * FROM test WHERE id = 5 AND values IN (1, 2, 3)");

View File

@@ -57,7 +57,7 @@ impl ScramSha256 {
/// Used for testing. /// Used for testing.
pub fn from_nonce(password: &str, nonce: &str) -> ScramSha256 { pub fn from_nonce(password: &str, nonce: &str) -> ScramSha256 {
let message = BytesMut::from(&format!("{}n=,r={}", "n,,", nonce).as_bytes()[..]); let message = BytesMut::from(format!("{}n=,r={}", "n,,", nonce).as_bytes());
ScramSha256 { ScramSha256 {
password: password.to_string(), password: password.to_string(),
@@ -87,7 +87,7 @@ impl ScramSha256 {
}; };
let salted_password = Self::hi( let salted_password = Self::hi(
&normalize(&self.password.as_bytes()[..]), &normalize(self.password.as_bytes()),
&salt, &salt,
server_message.iterations, server_message.iterations,
); );
@@ -181,7 +181,7 @@ impl ScramSha256 {
match hmac.verify_slice(&verifier) { match hmac.verify_slice(&verifier) {
Ok(_) => Ok(()), Ok(_) => Ok(()),
Err(_) => return Err(Error::ServerError), Err(_) => Err(Error::ServerError),
} }
} }
@@ -220,7 +220,7 @@ impl Message {
/// Parse the server SASL challenge. /// Parse the server SASL challenge.
fn parse(message: &BytesMut) -> Result<Message, Error> { fn parse(message: &BytesMut) -> Result<Message, Error> {
let parts = String::from_utf8_lossy(&message[..]) let parts = String::from_utf8_lossy(&message[..])
.split(",") .split(',')
.map(|s| s.to_string()) .map(|s| s.to_string())
.collect::<Vec<String>>(); .collect::<Vec<String>>();
@@ -268,7 +268,7 @@ mod test {
#[test] #[test]
fn parse_server_first_message() { fn parse_server_first_message() {
let message = BytesMut::from( let message = BytesMut::from(
&"r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,s=QSXCR+Q6sek8bf92,i=4096".as_bytes()[..], "r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,s=QSXCR+Q6sek8bf92,i=4096".as_bytes(),
); );
let message = Message::parse(&message).unwrap(); let message = Message::parse(&message).unwrap();
assert_eq!(message.nonce, "fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j"); assert_eq!(message.nonce, "fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j");
@@ -279,7 +279,7 @@ mod test {
#[test] #[test]
fn parse_server_last_message() { fn parse_server_last_message() {
let f = FinalMessage::parse(&BytesMut::from( let f = FinalMessage::parse(&BytesMut::from(
&"v=U+ppxD5XUKtradnv8e2MkeupiA8FU87Sg8CXzXHDAzw".as_bytes()[..], "v=U+ppxD5XUKtradnv8e2MkeupiA8FU87Sg8CXzXHDAzw".as_bytes(),
)) ))
.unwrap(); .unwrap();
assert_eq!( assert_eq!(
@@ -309,12 +309,12 @@ mod test {
assert_eq!(std::str::from_utf8(&message).unwrap(), client_first); assert_eq!(std::str::from_utf8(&message).unwrap(), client_first);
let result = scram let result = scram
.update(&BytesMut::from(&server_first.as_bytes()[..])) .update(&BytesMut::from(server_first.as_bytes()))
.unwrap(); .unwrap();
assert_eq!(std::str::from_utf8(&result).unwrap(), client_final); assert_eq!(std::str::from_utf8(&result).unwrap(), client_final);
scram scram
.finish(&BytesMut::from(&server_final.as_bytes()[..])) .finish(&BytesMut::from(server_final.as_bytes()))
.unwrap(); .unwrap();
} }
} }

View File

@@ -175,7 +175,7 @@ impl Server {
+ sasl_response.len() as i32, // length of SASL response + sasl_response.len() as i32, // length of SASL response
); );
res.put_slice(&format!("{}\0", SCRAM_SHA_256).as_bytes()[..]); res.put_slice(format!("{}\0", SCRAM_SHA_256).as_bytes());
res.put_i32(sasl_response.len() as i32); res.put_i32(sasl_response.len() as i32);
res.put(sasl_response); res.put(sasl_response);
@@ -315,19 +315,19 @@ impl Server {
let mut server = Server { let mut server = Server {
address: address.clone(), address: address.clone(),
read: BufReader::new(read), read: BufReader::new(read),
write: write, write,
buffer: BytesMut::with_capacity(8196), buffer: BytesMut::with_capacity(8196),
server_info: server_info, server_info,
server_id: server_id, server_id,
process_id: process_id, process_id,
secret_key: secret_key, secret_key,
in_transaction: false, in_transaction: false,
data_available: false, data_available: false,
bad: false, bad: false,
needs_cleanup: false, needs_cleanup: false,
client_server_map: client_server_map, client_server_map,
connected_at: chrono::offset::Utc::now().naive_utc(), connected_at: chrono::offset::Utc::now().naive_utc(),
stats: stats, stats,
application_name: String::new(), application_name: String::new(),
last_activity: SystemTime::now(), last_activity: SystemTime::now(),
}; };
@@ -371,7 +371,7 @@ impl Server {
bytes.put_i32(process_id); bytes.put_i32(process_id);
bytes.put_i32(secret_key); bytes.put_i32(secret_key);
Ok(write_all(&mut stream, bytes).await?) write_all(&mut stream, bytes).await
} }
/// Send messages to the server from the client. /// Send messages to the server from the client.
@@ -616,7 +616,7 @@ impl Server {
self.needs_cleanup = false; self.needs_cleanup = false;
} }
return Ok(()); Ok(())
} }
/// A shorthand for `SET application_name = $1`. /// A shorthand for `SET application_name = $1`.
@@ -631,7 +631,7 @@ impl Server {
.query(&format!("SET application_name = '{}'", name)) .query(&format!("SET application_name = '{}'", name))
.await?); .await?);
self.needs_cleanup = needs_cleanup_before; self.needs_cleanup = needs_cleanup_before;
return result; result
} else { } else {
Ok(()) Ok(())
} }

View File

@@ -133,7 +133,7 @@ impl Sharder {
#[inline] #[inline]
fn combine(mut a: u64, b: u64) -> u64 { fn combine(mut a: u64, b: u64) -> u64 {
a ^= b a ^= b
.wrapping_add(0x49a0f4dd15e5a8e3 as u64) .wrapping_add(0x49a0f4dd15e5a8e3_u64)
.wrapping_add(a << 54) .wrapping_add(a << 54)
.wrapping_add(a >> 7); .wrapping_add(a >> 7);
a a
@@ -141,7 +141,7 @@ impl Sharder {
#[inline] #[inline]
fn pg_u32_hash(k: u32) -> u64 { fn pg_u32_hash(k: u32) -> u64 {
let mut a: u32 = 0x9e3779b9 as u32 + std::mem::size_of::<u32>() as u32 + 3923095 as u32; let mut a: u32 = 0x9e3779b9_u32 + std::mem::size_of::<u32>() as u32 + 3923095_u32;
let mut b = a; let mut b = a;
let c = a; let c = a;

View File

@@ -245,7 +245,7 @@ impl Default for Reporter {
impl Reporter { impl Reporter {
/// Create a new Reporter instance. /// Create a new Reporter instance.
pub fn new(tx: Sender<Event>) -> Reporter { pub fn new(tx: Sender<Event>) -> Reporter {
Reporter { tx: tx } Reporter { tx }
} }
/// Send statistics to the task keeping track of stats. /// Send statistics to the task keeping track of stats.
@@ -338,9 +338,9 @@ impl Reporter {
let event = Event { let event = Event {
name: EventName::ClientRegistered { name: EventName::ClientRegistered {
client_id, client_id,
pool_name: pool_name.clone(), pool_name,
username: username.clone(), username,
application_name: app_name.clone(), application_name: app_name,
}, },
value: 1, value: 1,
}; };
@@ -582,7 +582,7 @@ impl Collector {
let address_stats = address_stat_lookup let address_stats = address_stat_lookup
.entry(server_info.address_id) .entry(server_info.address_id)
.or_insert(HashMap::default()); .or_insert_with(HashMap::default);
let counter = address_stats let counter = address_stats
.entry("total_query_count".to_string()) .entry("total_query_count".to_string())
.or_insert(0); .or_insert(0);
@@ -618,7 +618,7 @@ impl Collector {
let address_stats = address_stat_lookup let address_stats = address_stat_lookup
.entry(server_info.address_id) .entry(server_info.address_id)
.or_insert(HashMap::default()); .or_insert_with(HashMap::default);
let counter = address_stats let counter = address_stats
.entry("total_xact_count".to_string()) .entry("total_xact_count".to_string())
.or_insert(0); .or_insert(0);
@@ -636,7 +636,7 @@ impl Collector {
let address_stats = address_stat_lookup let address_stats = address_stat_lookup
.entry(server_info.address_id) .entry(server_info.address_id)
.or_insert(HashMap::default()); .or_insert_with(HashMap::default);
let counter = let counter =
address_stats.entry("total_sent".to_string()).or_insert(0); address_stats.entry("total_sent".to_string()).or_insert(0);
*counter += stat.value; *counter += stat.value;
@@ -653,7 +653,7 @@ impl Collector {
let address_stats = address_stat_lookup let address_stats = address_stat_lookup
.entry(server_info.address_id) .entry(server_info.address_id)
.or_insert(HashMap::default()); .or_insert_with(HashMap::default);
let counter = address_stats let counter = address_stats
.entry("total_received".to_string()) .entry("total_received".to_string())
.or_insert(0); .or_insert(0);
@@ -683,7 +683,7 @@ impl Collector {
let address_stats = address_stat_lookup let address_stats = address_stat_lookup
.entry(server_info.address_id) .entry(server_info.address_id)
.or_insert(HashMap::default()); .or_insert_with(HashMap::default);
let counter = address_stats let counter = address_stats
.entry("total_wait_time".to_string()) .entry("total_wait_time".to_string())
.or_insert(0); .or_insert(0);
@@ -694,7 +694,7 @@ impl Collector {
server_info.pool_name.clone(), server_info.pool_name.clone(),
server_info.username.clone(), server_info.username.clone(),
)) ))
.or_insert(HashMap::default()); .or_insert_with(HashMap::default);
// We record max wait in microseconds, we do the pgbouncer second/microsecond split on admin // We record max wait in microseconds, we do the pgbouncer second/microsecond split on admin
let old_microseconds = let old_microseconds =
@@ -750,7 +750,7 @@ impl Collector {
// Update address aggregation stats // Update address aggregation stats
let address_stats = address_stat_lookup let address_stats = address_stat_lookup
.entry(address_id) .entry(address_id)
.or_insert(HashMap::default()); .or_insert_with(HashMap::default);
let counter = address_stats.entry("total_errors".to_string()).or_insert(0); let counter = address_stats.entry("total_errors".to_string()).or_insert(0);
*counter += stat.value; *counter += stat.value;
} }
@@ -770,7 +770,7 @@ impl Collector {
// Update address aggregation stats // Update address aggregation stats
let address_stats = address_stat_lookup let address_stats = address_stat_lookup
.entry(address_id) .entry(address_id)
.or_insert(HashMap::default()); .or_insert_with(HashMap::default);
let counter = address_stats.entry("total_errors".to_string()).or_insert(0); let counter = address_stats.entry("total_errors".to_string()).or_insert(0);
*counter += stat.value; *counter += stat.value;
} }
@@ -891,7 +891,7 @@ impl Collector {
} => { } => {
let pool_stats = pool_stat_lookup let pool_stats = pool_stat_lookup
.entry((pool_name.clone(), username.clone())) .entry((pool_name.clone(), username.clone()))
.or_insert(HashMap::default()); .or_insert_with(HashMap::default);
// These are re-calculated every iteration of the loop, so we don't want to add values // These are re-calculated every iteration of the loop, so we don't want to add values
// from the last iteration. // from the last iteration.
@@ -964,17 +964,17 @@ impl Collector {
// Clear maxwait after reporting // Clear maxwait after reporting
pool_stat_lookup pool_stat_lookup
.entry((pool_name.clone(), username.clone())) .entry((pool_name.clone(), username.clone()))
.or_insert(HashMap::default()) .or_insert_with(HashMap::default)
.insert("maxwait_us".to_string(), 0); .insert("maxwait_us".to_string(), 0);
} }
EventName::UpdateAverages { address_id } => { EventName::UpdateAverages { address_id } => {
let stats = address_stat_lookup let stats = address_stat_lookup
.entry(address_id) .entry(address_id)
.or_insert(HashMap::default()); .or_insert_with(HashMap::default);
let old_stats = address_old_stat_lookup let old_stats = address_old_stat_lookup
.entry(address_id) .entry(address_id)
.or_insert(HashMap::default()); .or_insert_with(HashMap::default);
// Calculate averages // Calculate averages
for stat in &[ for stat in &[

View File

@@ -30,12 +30,12 @@ impl Tls {
pub fn new() -> Result<Self, Error> { pub fn new() -> Result<Self, Error> {
let config = get_config(); let config = get_config();
let certs = match load_certs(&Path::new(&config.general.tls_certificate.unwrap())) { let certs = match load_certs(Path::new(&config.general.tls_certificate.unwrap())) {
Ok(certs) => certs, Ok(certs) => certs,
Err(_) => return Err(Error::TlsError), Err(_) => return Err(Error::TlsError),
}; };
let mut keys = match load_keys(&Path::new(&config.general.tls_private_key.unwrap())) { let mut keys = match load_keys(Path::new(&config.general.tls_private_key.unwrap())) {
Ok(keys) => keys, Ok(keys) => keys,
Err(_) => return Err(Error::TlsError), Err(_) => return Err(Error::TlsError),
}; };