mirror of
https://github.com/postgresml/pgcat.git
synced 2026-03-23 01:16:30 +00:00
chore: make clippy lint happy (#225)
* chore: make clippy happy * chore: cargo fmt * chore: cargo fmt
This commit is contained in:
@@ -22,7 +22,7 @@ pub fn generate_server_info_for_admin() -> BytesMut {
|
||||
server_info.put(server_parameter_message("server_version", VERSION));
|
||||
server_info.put(server_parameter_message("DateStyle", "ISO, MDY"));
|
||||
|
||||
return server_info;
|
||||
server_info
|
||||
}
|
||||
|
||||
/// Handle admin client.
|
||||
@@ -179,7 +179,7 @@ where
|
||||
let mut res = BytesMut::new();
|
||||
|
||||
res.put(row_description(&vec![("version", DataType::Text)]));
|
||||
res.put(data_row(&vec![format!("PgCat {}", VERSION).to_string()]));
|
||||
res.put(data_row(&vec![format!("PgCat {}", VERSION)]));
|
||||
res.put(command_complete("SHOW"));
|
||||
|
||||
res.put_u8(b'Z');
|
||||
|
||||
@@ -377,7 +377,7 @@ where
|
||||
|
||||
let admin = ["pgcat", "pgbouncer"]
|
||||
.iter()
|
||||
.filter(|db| *db == &pool_name)
|
||||
.filter(|db| *db == pool_name)
|
||||
.count()
|
||||
== 1;
|
||||
|
||||
@@ -389,7 +389,7 @@ where
|
||||
);
|
||||
error_response_terminal(
|
||||
&mut write,
|
||||
&format!("terminating connection due to administrator command"),
|
||||
"terminating connection due to administrator command",
|
||||
)
|
||||
.await?;
|
||||
return Err(Error::ShuttingDown);
|
||||
@@ -446,7 +446,7 @@ where
|
||||
}
|
||||
// Authenticate normal user.
|
||||
else {
|
||||
let pool = match get_pool(&pool_name, &username) {
|
||||
let pool = match get_pool(pool_name, username) {
|
||||
Some(pool) => pool,
|
||||
None => {
|
||||
error_response(
|
||||
@@ -464,7 +464,7 @@ where
|
||||
};
|
||||
|
||||
// Compare server and client hashes.
|
||||
let password_hash = md5_hash_password(&username, &pool.settings.user.password, &salt);
|
||||
let password_hash = md5_hash_password(username, &pool.settings.user.password, &salt);
|
||||
|
||||
if password_hash != password_response {
|
||||
warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name);
|
||||
@@ -487,9 +487,9 @@ where
|
||||
|
||||
trace!("Startup OK");
|
||||
|
||||
return Ok(Client {
|
||||
Ok(Client {
|
||||
read: BufReader::new(read),
|
||||
write: write,
|
||||
write,
|
||||
addr,
|
||||
buffer: BytesMut::with_capacity(8196),
|
||||
cancel_mode: false,
|
||||
@@ -498,8 +498,8 @@ where
|
||||
secret_key,
|
||||
client_server_map,
|
||||
parameters: parameters.clone(),
|
||||
stats: stats,
|
||||
admin: admin,
|
||||
stats,
|
||||
admin,
|
||||
last_address_id: None,
|
||||
last_server_id: None,
|
||||
pool_name: pool_name.clone(),
|
||||
@@ -507,7 +507,7 @@ where
|
||||
application_name: application_name.to_string(),
|
||||
shutdown,
|
||||
connected_to_server: false,
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
/// Handle cancel request.
|
||||
@@ -521,9 +521,9 @@ where
|
||||
) -> Result<Client<S, T>, Error> {
|
||||
let process_id = bytes.get_i32();
|
||||
let secret_key = bytes.get_i32();
|
||||
return Ok(Client {
|
||||
Ok(Client {
|
||||
read: BufReader::new(read),
|
||||
write: write,
|
||||
write,
|
||||
addr,
|
||||
buffer: BytesMut::with_capacity(8196),
|
||||
cancel_mode: true,
|
||||
@@ -541,7 +541,7 @@ where
|
||||
application_name: String::from("undefined"),
|
||||
shutdown,
|
||||
connected_to_server: false,
|
||||
});
|
||||
})
|
||||
}
|
||||
|
||||
/// Handle a connected and authenticated client.
|
||||
@@ -557,12 +557,9 @@ where
|
||||
// Drop the mutex as soon as possible.
|
||||
// We found the server the client is using for its query
|
||||
// that it wants to cancel.
|
||||
Some((process_id, secret_key, address, port)) => (
|
||||
process_id.clone(),
|
||||
secret_key.clone(),
|
||||
address.clone(),
|
||||
*port,
|
||||
),
|
||||
Some((process_id, secret_key, address, port)) => {
|
||||
(*process_id, *secret_key, address.clone(), *port)
|
||||
}
|
||||
|
||||
// The client doesn't know / got the wrong server,
|
||||
// we're closing the connection for security reasons.
|
||||
@@ -573,7 +570,7 @@ where
|
||||
// Opens a new separate connection to the server, sends the backend_id
|
||||
// and secret_key and then closes it for security reasons. No other interactions
|
||||
// take place.
|
||||
return Ok(Server::cancel(&address, port, process_id, secret_key).await?);
|
||||
return Server::cancel(&address, port, process_id, secret_key).await;
|
||||
}
|
||||
|
||||
// The query router determines where the query is going to go,
|
||||
@@ -606,7 +603,7 @@ where
|
||||
if !self.admin {
|
||||
error_response_terminal(
|
||||
&mut self.write,
|
||||
&format!("terminating connection due to administrator command")
|
||||
"terminating connection due to administrator command"
|
||||
).await?;
|
||||
return Ok(())
|
||||
}
|
||||
@@ -998,14 +995,14 @@ where
|
||||
) -> Result<(), Error> {
|
||||
debug!("Sending {} to server", code);
|
||||
|
||||
self.send_server_message(server, message, &address, &pool)
|
||||
self.send_server_message(server, message, address, pool)
|
||||
.await?;
|
||||
|
||||
let query_start = Instant::now();
|
||||
// Read all data the server has to offer, which can be multiple messages
|
||||
// buffered in 8196 bytes chunks.
|
||||
loop {
|
||||
let response = self.receive_server_message(server, &address, &pool).await?;
|
||||
let response = self.receive_server_message(server, address, pool).await?;
|
||||
|
||||
match write_all_half(&mut self.write, response).await {
|
||||
Ok(_) => (),
|
||||
|
||||
@@ -9,7 +9,6 @@ use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use tokio::fs::File;
|
||||
use tokio::io::AsyncReadExt;
|
||||
use toml;
|
||||
|
||||
use crate::errors::Error;
|
||||
use crate::pool::{ClientServerMap, ConnectionPool};
|
||||
@@ -353,7 +352,7 @@ impl Shard {
|
||||
let mut dup_check = HashSet::new();
|
||||
let mut primary_count = 0;
|
||||
|
||||
if self.servers.len() == 0 {
|
||||
if self.servers.is_empty() {
|
||||
error!("Shard {} has no servers configured", self.database);
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
@@ -362,10 +361,9 @@ impl Shard {
|
||||
dup_check.insert(server);
|
||||
|
||||
// Check that we define only zero or one primary.
|
||||
match server.role {
|
||||
Role::Primary => primary_count += 1,
|
||||
_ => (),
|
||||
};
|
||||
if server.role == Role::Primary {
|
||||
primary_count += 1
|
||||
}
|
||||
}
|
||||
|
||||
if primary_count > 1 {
|
||||
@@ -605,22 +603,17 @@ impl Config {
|
||||
// Validate TLS!
|
||||
match self.general.tls_certificate.clone() {
|
||||
Some(tls_certificate) => {
|
||||
match load_certs(&Path::new(&tls_certificate)) {
|
||||
match load_certs(Path::new(&tls_certificate)) {
|
||||
Ok(_) => {
|
||||
// Cert is okay, but what about the private key?
|
||||
match self.general.tls_private_key.clone() {
|
||||
Some(tls_private_key) => {
|
||||
match load_keys(&Path::new(&tls_private_key)) {
|
||||
Ok(_) => (),
|
||||
Err(err) => {
|
||||
error!(
|
||||
"tls_private_key is incorrectly configured: {:?}",
|
||||
err
|
||||
);
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
Some(tls_private_key) => match load_keys(Path::new(&tls_private_key)) {
|
||||
Ok(_) => (),
|
||||
Err(err) => {
|
||||
error!("tls_private_key is incorrectly configured: {:?}", err);
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
None => {
|
||||
error!("tls_certificate is set, but the tls_private_key is not");
|
||||
@@ -638,7 +631,7 @@ impl Config {
|
||||
None => (),
|
||||
};
|
||||
|
||||
for (_, pool) in &mut self.pools {
|
||||
for pool in self.pools.values_mut() {
|
||||
pool.validate()?;
|
||||
}
|
||||
|
||||
|
||||
23
src/main.rs
23
src/main.rs
@@ -75,7 +75,6 @@ mod stats;
|
||||
mod tls;
|
||||
|
||||
use crate::config::{get_config, reload_config, VERSION};
|
||||
use crate::errors::Error;
|
||||
use crate::pool::{ClientServerMap, ConnectionPool};
|
||||
use crate::prometheus::start_metric_server;
|
||||
use crate::stats::{Collector, Reporter, REPORTER};
|
||||
@@ -171,13 +170,10 @@ async fn main() {
|
||||
if config.general.autoreload {
|
||||
info!("Automatically reloading config");
|
||||
|
||||
match reload_config(autoreload_client_server_map.clone()).await {
|
||||
Ok(changed) => {
|
||||
if changed {
|
||||
get_config().show()
|
||||
}
|
||||
if let Ok(changed) = reload_config(autoreload_client_server_map.clone()).await {
|
||||
if changed {
|
||||
get_config().show()
|
||||
}
|
||||
Err(_) => (),
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -202,10 +198,7 @@ async fn main() {
|
||||
_ = sighup_signal.recv() => {
|
||||
info!("Reloading config");
|
||||
|
||||
match reload_config(client_server_map.clone()).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => (),
|
||||
};
|
||||
_ = reload_config(client_server_map.clone()).await;
|
||||
|
||||
get_config().show();
|
||||
},
|
||||
@@ -278,14 +271,6 @@ async fn main() {
|
||||
}
|
||||
|
||||
Err(err) => {
|
||||
match err {
|
||||
// Don't count the clients we rejected.
|
||||
Error::ShuttingDown => (),
|
||||
_ => {
|
||||
// drain_tx.send(-1).await.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
warn!("Client disconnected with error {:?}", err);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -38,7 +38,7 @@ where
|
||||
auth_ok.put_i32(8);
|
||||
auth_ok.put_i32(0);
|
||||
|
||||
Ok(write_all(stream, auth_ok).await?)
|
||||
write_all(stream, auth_ok).await
|
||||
}
|
||||
|
||||
/// Generate md5 password challenge.
|
||||
@@ -79,7 +79,7 @@ where
|
||||
key_data.put_i32(backend_id);
|
||||
key_data.put_i32(secret_key);
|
||||
|
||||
Ok(write_all(stream, key_data).await?)
|
||||
write_all(stream, key_data).await
|
||||
}
|
||||
|
||||
/// Construct a `Q`: Query message.
|
||||
@@ -88,7 +88,7 @@ pub fn simple_query(query: &str) -> BytesMut {
|
||||
let query = format!("{}\0", query);
|
||||
|
||||
res.put_i32(query.len() as i32 + 4);
|
||||
res.put_slice(&query.as_bytes());
|
||||
res.put_slice(query.as_bytes());
|
||||
|
||||
res
|
||||
}
|
||||
@@ -106,7 +106,7 @@ where
|
||||
bytes.put_i32(5);
|
||||
bytes.put_u8(b'I'); // Idle
|
||||
|
||||
Ok(write_all(stream, bytes).await?)
|
||||
write_all(stream, bytes).await
|
||||
}
|
||||
|
||||
/// Send the startup packet the server. We're pretending we're a Pg client.
|
||||
@@ -118,12 +118,12 @@ pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Resu
|
||||
|
||||
// User
|
||||
bytes.put(&b"user\0"[..]);
|
||||
bytes.put_slice(&user.as_bytes());
|
||||
bytes.put_slice(user.as_bytes());
|
||||
bytes.put_u8(0);
|
||||
|
||||
// Database
|
||||
bytes.put(&b"database\0"[..]);
|
||||
bytes.put_slice(&database.as_bytes());
|
||||
bytes.put_slice(database.as_bytes());
|
||||
bytes.put_u8(0);
|
||||
bytes.put_u8(0); // Null terminator
|
||||
|
||||
@@ -136,7 +136,7 @@ pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Resu
|
||||
|
||||
match stream.write_all(&startup).await {
|
||||
Ok(_) => Ok(()),
|
||||
Err(_) => return Err(Error::SocketError),
|
||||
Err(_) => Err(Error::SocketError),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -155,7 +155,7 @@ pub fn parse_params(mut bytes: BytesMut) -> Result<HashMap<String, String>, Erro
|
||||
c = bytes.get_u8();
|
||||
}
|
||||
|
||||
if tmp.len() > 0 {
|
||||
if !tmp.is_empty() {
|
||||
buf.push(tmp.clone());
|
||||
tmp.clear();
|
||||
}
|
||||
@@ -234,7 +234,7 @@ where
|
||||
message.put_i32(password.len() as i32 + 4);
|
||||
message.put_slice(&password[..]);
|
||||
|
||||
Ok(write_all(stream, message).await?)
|
||||
write_all(stream, message).await
|
||||
}
|
||||
|
||||
/// Implements a response to our custom `SET SHARDING KEY`
|
||||
@@ -292,7 +292,7 @@ where
|
||||
|
||||
// The short error message.
|
||||
error.put_u8(b'M');
|
||||
error.put_slice(&format!("{}\0", message).as_bytes());
|
||||
error.put_slice(format!("{}\0", message).as_bytes());
|
||||
|
||||
// No more fields follow.
|
||||
error.put_u8(0);
|
||||
@@ -304,7 +304,7 @@ where
|
||||
res.put_i32(error.len() as i32 + 4);
|
||||
res.put(error);
|
||||
|
||||
Ok(write_all_half(stream, res).await?)
|
||||
write_all_half(stream, res).await
|
||||
}
|
||||
|
||||
pub async fn wrong_password<S>(stream: &mut S, user: &str) -> Result<(), Error>
|
||||
@@ -327,7 +327,7 @@ where
|
||||
|
||||
// The short error message.
|
||||
error.put_u8(b'M');
|
||||
error.put_slice(&format!("password authentication failed for user \"{}\"\0", user).as_bytes());
|
||||
error.put_slice(format!("password authentication failed for user \"{}\"\0", user).as_bytes());
|
||||
|
||||
// No more fields follow.
|
||||
error.put_u8(0);
|
||||
@@ -379,7 +379,7 @@ pub fn row_description(columns: &Vec<(&str, DataType)>) -> BytesMut {
|
||||
|
||||
for (name, data_type) in columns {
|
||||
// Column name
|
||||
row_desc.put_slice(&format!("{}\0", name).as_bytes());
|
||||
row_desc.put_slice(format!("{}\0", name).as_bytes());
|
||||
|
||||
// Doesn't belong to any table
|
||||
row_desc.put_i32(0);
|
||||
@@ -423,7 +423,7 @@ pub fn data_row(row: &Vec<String>) -> BytesMut {
|
||||
for column in row {
|
||||
let column = column.as_bytes();
|
||||
data_row.put_i32(column.len() as i32);
|
||||
data_row.put_slice(&column);
|
||||
data_row.put_slice(column);
|
||||
}
|
||||
|
||||
res.put_u8(b'D');
|
||||
@@ -450,7 +450,7 @@ where
|
||||
{
|
||||
match stream.write_all(&buf).await {
|
||||
Ok(_) => Ok(()),
|
||||
Err(_) => return Err(Error::SocketError),
|
||||
Err(_) => Err(Error::SocketError),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -461,7 +461,7 @@ where
|
||||
{
|
||||
match stream.write_all(&buf).await {
|
||||
Ok(_) => Ok(()),
|
||||
Err(_) => return Err(Error::SocketError),
|
||||
Err(_) => Err(Error::SocketError),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -510,5 +510,5 @@ pub fn server_parameter_message(key: &str, value: &str) -> BytesMut {
|
||||
server_info.put_slice(value.as_bytes());
|
||||
server_info.put_bytes(0, 1);
|
||||
|
||||
return server_info;
|
||||
server_info
|
||||
}
|
||||
|
||||
46
src/pool.rs
46
src/pool.rs
@@ -140,18 +140,18 @@ impl ConnectionPool {
|
||||
let changed = pools_hash.insert(pool_config.clone());
|
||||
|
||||
// There is one pool per database/user pair.
|
||||
for (_, user) in &pool_config.users {
|
||||
for user in pool_config.users.values() {
|
||||
// If the pool hasn't changed, get existing reference and insert it into the new_pools.
|
||||
// We replace all pools at the end, but if the reference is kept, the pool won't get re-created (bb8).
|
||||
if !changed {
|
||||
match get_pool(&pool_name, &user.username) {
|
||||
match get_pool(pool_name, &user.username) {
|
||||
Some(pool) => {
|
||||
info!(
|
||||
"[pool: {}][user: {}] has not changed",
|
||||
pool_name, user.username
|
||||
);
|
||||
new_pools.insert(
|
||||
PoolIdentifier::new(&pool_name, &user.username),
|
||||
PoolIdentifier::new(pool_name, &user.username),
|
||||
pool.clone(),
|
||||
);
|
||||
continue;
|
||||
@@ -172,7 +172,6 @@ impl ConnectionPool {
|
||||
.shards
|
||||
.clone()
|
||||
.into_keys()
|
||||
.map(|x| x.to_string())
|
||||
.collect::<Vec<String>>();
|
||||
|
||||
// Sort by shard number to ensure consistency.
|
||||
@@ -182,10 +181,9 @@ impl ConnectionPool {
|
||||
let shard = &pool_config.shards[shard_idx];
|
||||
let mut pools = Vec::new();
|
||||
let mut servers = Vec::new();
|
||||
let mut address_index = 0;
|
||||
let mut replica_number = 0;
|
||||
|
||||
for server in shard.servers.iter() {
|
||||
for (address_index, server) in shard.servers.iter().enumerate() {
|
||||
let address = Address {
|
||||
id: address_id,
|
||||
database: shard.database.clone(),
|
||||
@@ -200,7 +198,6 @@ impl ConnectionPool {
|
||||
};
|
||||
|
||||
address_id += 1;
|
||||
address_index += 1;
|
||||
|
||||
if server.role == Role::Replica {
|
||||
replica_number += 1;
|
||||
@@ -240,7 +237,7 @@ impl ConnectionPool {
|
||||
|
||||
let mut pool = ConnectionPool {
|
||||
databases: shards,
|
||||
addresses: addresses,
|
||||
addresses,
|
||||
banlist: Arc::new(RwLock::new(banlist)),
|
||||
stats: get_reporter(),
|
||||
server_info: BytesMut::new(),
|
||||
@@ -255,7 +252,7 @@ impl ConnectionPool {
|
||||
"primary" => Some(Role::Primary),
|
||||
_ => unreachable!(),
|
||||
},
|
||||
query_parser_enabled: pool_config.query_parser_enabled.clone(),
|
||||
query_parser_enabled: pool_config.query_parser_enabled,
|
||||
primary_reads_enabled: pool_config.primary_reads_enabled,
|
||||
sharding_function: pool_config.sharding_function,
|
||||
automatic_sharding_key: pool_config.automatic_sharding_key.clone(),
|
||||
@@ -273,7 +270,7 @@ impl ConnectionPool {
|
||||
};
|
||||
|
||||
// There is one pool per database/user pair.
|
||||
new_pools.insert(PoolIdentifier::new(&pool_name, &user.username), pool);
|
||||
new_pools.insert(PoolIdentifier::new(pool_name, &user.username), pool);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -304,7 +301,7 @@ impl ConnectionPool {
|
||||
let server = &*proxy;
|
||||
let server_info = server.server_info();
|
||||
|
||||
if server_infos.len() > 0 {
|
||||
if !server_infos.is_empty() {
|
||||
// Compare against the last server checked.
|
||||
if server_info != server_infos[server_infos.len() - 1] {
|
||||
warn!(
|
||||
@@ -320,7 +317,7 @@ impl ConnectionPool {
|
||||
|
||||
// TODO: compare server information to make sure
|
||||
// all shards are running identical configurations.
|
||||
if server_infos.len() == 0 {
|
||||
if server_infos.is_empty() {
|
||||
return Err(Error::AllServersDown);
|
||||
}
|
||||
|
||||
@@ -356,7 +353,7 @@ impl ConnectionPool {
|
||||
None => break,
|
||||
};
|
||||
|
||||
if self.is_banned(&address, role) {
|
||||
if self.is_banned(address, role) {
|
||||
debug!("Address {:?} is banned", address);
|
||||
continue;
|
||||
}
|
||||
@@ -373,7 +370,7 @@ impl ConnectionPool {
|
||||
Ok(conn) => conn,
|
||||
Err(err) => {
|
||||
error!("Banning instance {:?}, error: {:?}", address, err);
|
||||
self.ban(&address, process_id);
|
||||
self.ban(address, process_id);
|
||||
self.stats.client_checkout_error(process_id, address.id);
|
||||
continue;
|
||||
}
|
||||
@@ -428,7 +425,7 @@ impl ConnectionPool {
|
||||
// Don't leave a bad connection in the pool.
|
||||
server.mark_bad();
|
||||
|
||||
self.ban(&address, process_id);
|
||||
self.ban(address, process_id);
|
||||
continue;
|
||||
}
|
||||
},
|
||||
@@ -442,7 +439,7 @@ impl ConnectionPool {
|
||||
// Don't leave a bad connection in the pool.
|
||||
server.mark_bad();
|
||||
|
||||
self.ban(&address, process_id);
|
||||
self.ban(address, process_id);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
@@ -575,11 +572,11 @@ impl ServerPool {
|
||||
stats: Reporter,
|
||||
) -> ServerPool {
|
||||
ServerPool {
|
||||
address: address,
|
||||
user: user,
|
||||
address,
|
||||
user,
|
||||
database: database.to_string(),
|
||||
client_server_map: client_server_map,
|
||||
stats: stats,
|
||||
client_server_map,
|
||||
stats,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -638,15 +635,14 @@ impl ManageConnection for ServerPool {
|
||||
|
||||
/// Get the connection pool
|
||||
pub fn get_pool(db: &str, user: &str) -> Option<ConnectionPool> {
|
||||
match (*(*POOLS.load())).get(&PoolIdentifier::new(db, user)) {
|
||||
Some(pool) => Some(pool.clone()),
|
||||
None => None,
|
||||
}
|
||||
(*(*POOLS.load()))
|
||||
.get(&PoolIdentifier::new(db, user))
|
||||
.cloned()
|
||||
}
|
||||
|
||||
/// Get a pointer to all configured pools.
|
||||
pub fn get_all_pools() -> HashMap<PoolIdentifier, ConnectionPool> {
|
||||
return (*(*POOLS.load())).clone();
|
||||
(*(*POOLS.load())).clone()
|
||||
}
|
||||
|
||||
/// How many total servers we have in the config.
|
||||
|
||||
@@ -203,7 +203,7 @@ async fn prometheus_stats(request: Request<Body>) -> Result<Response<Body>, hype
|
||||
pub async fn start_metric_server(http_addr: SocketAddr) {
|
||||
let http_service_factory =
|
||||
make_service_fn(|_conn| async { Ok::<_, hyper::Error>(service_fn(prometheus_stats)) });
|
||||
let server = Server::bind(&http_addr.into()).serve(http_service_factory);
|
||||
let server = Server::bind(&http_addr).serve(http_service_factory);
|
||||
info!(
|
||||
"Exposing prometheus metrics on http://{}/metrics.",
|
||||
http_addr
|
||||
|
||||
@@ -86,10 +86,7 @@ impl QueryRouter {
|
||||
Err(_) => return false,
|
||||
};
|
||||
|
||||
match CUSTOM_SQL_REGEX_SET.set(set) {
|
||||
Ok(_) => true,
|
||||
Err(_) => false,
|
||||
}
|
||||
CUSTOM_SQL_REGEX_SET.set(set).is_ok()
|
||||
}
|
||||
|
||||
/// Create a new instance of the query router.
|
||||
@@ -276,7 +273,6 @@ impl QueryRouter {
|
||||
// Parse (prepared statement)
|
||||
'P' => {
|
||||
let mut start = 0;
|
||||
let mut end;
|
||||
|
||||
// Skip the name of the prepared statement.
|
||||
while buf[start] != 0 && start < buf.len() {
|
||||
@@ -285,7 +281,7 @@ impl QueryRouter {
|
||||
start += 1; // Skip terminating null
|
||||
|
||||
// Find the end of the prepared stmt (\0)
|
||||
end = start;
|
||||
let mut end = start;
|
||||
while buf[end] != 0 && end < buf.len() {
|
||||
end += 1;
|
||||
}
|
||||
@@ -294,7 +290,7 @@ impl QueryRouter {
|
||||
|
||||
debug!("Prepared statement: '{}'", query);
|
||||
|
||||
query.replace("$", "") // Remove placeholders turning them into "values"
|
||||
query.replace('$', "") // Remove placeholders turning them into "values"
|
||||
}
|
||||
|
||||
_ => return false,
|
||||
@@ -312,7 +308,7 @@ impl QueryRouter {
|
||||
|
||||
debug!("AST: {:?}", ast);
|
||||
|
||||
if ast.len() == 0 {
|
||||
if ast.is_empty() {
|
||||
// That's weird, no idea, let's go to primary
|
||||
self.active_role = Some(Role::Primary);
|
||||
return false;
|
||||
@@ -371,50 +367,46 @@ impl QueryRouter {
|
||||
let mut result = Vec::new();
|
||||
let mut found = false;
|
||||
|
||||
match expr {
|
||||
// This parses `sharding_key = 5`. But it's technically
|
||||
// legal to write `5 = sharding_key`. I don't judge the people
|
||||
// who do that, but I think ORMs will still use the first variant,
|
||||
// so we can leave the second as a TODO.
|
||||
Expr::BinaryOp { left, op, right } => {
|
||||
match &**left {
|
||||
Expr::BinaryOp { .. } => result.extend(self.selection_parser(&left)),
|
||||
Expr::Identifier(ident) => {
|
||||
found = ident.value
|
||||
== *self.pool_settings.automatic_sharding_key.as_ref().unwrap();
|
||||
}
|
||||
_ => (),
|
||||
};
|
||||
// This parses `sharding_key = 5`. But it's technically
|
||||
// legal to write `5 = sharding_key`. I don't judge the people
|
||||
// who do that, but I think ORMs will still use the first variant,
|
||||
// so we can leave the second as a TODO.
|
||||
if let Expr::BinaryOp { left, op, right } = expr {
|
||||
match &**left {
|
||||
Expr::BinaryOp { .. } => result.extend(self.selection_parser(left)),
|
||||
Expr::Identifier(ident) => {
|
||||
found =
|
||||
ident.value == *self.pool_settings.automatic_sharding_key.as_ref().unwrap();
|
||||
}
|
||||
_ => (),
|
||||
};
|
||||
|
||||
match op {
|
||||
BinaryOperator::Eq => (),
|
||||
BinaryOperator::Or => (),
|
||||
BinaryOperator::And => (),
|
||||
_ => {
|
||||
// TODO: support other operators than equality.
|
||||
debug!("Unsupported operation: {:?}", op);
|
||||
return Vec::new();
|
||||
}
|
||||
};
|
||||
match op {
|
||||
BinaryOperator::Eq => (),
|
||||
BinaryOperator::Or => (),
|
||||
BinaryOperator::And => (),
|
||||
_ => {
|
||||
// TODO: support other operators than equality.
|
||||
debug!("Unsupported operation: {:?}", op);
|
||||
return Vec::new();
|
||||
}
|
||||
};
|
||||
|
||||
match &**right {
|
||||
Expr::BinaryOp { .. } => result.extend(self.selection_parser(&right)),
|
||||
Expr::Value(Value::Number(value, ..)) => {
|
||||
if found {
|
||||
match value.parse::<i64>() {
|
||||
Ok(value) => result.push(value),
|
||||
Err(_) => {
|
||||
debug!("Sharding key was not an integer: {}", value);
|
||||
}
|
||||
};
|
||||
}
|
||||
match &**right {
|
||||
Expr::BinaryOp { .. } => result.extend(self.selection_parser(right)),
|
||||
Expr::Value(Value::Number(value, ..)) => {
|
||||
if found {
|
||||
match value.parse::<i64>() {
|
||||
Ok(value) => result.push(value),
|
||||
Err(_) => {
|
||||
debug!("Sharding key was not an integer: {}", value);
|
||||
}
|
||||
};
|
||||
}
|
||||
_ => (),
|
||||
};
|
||||
}
|
||||
|
||||
_ => (),
|
||||
};
|
||||
}
|
||||
_ => (),
|
||||
};
|
||||
}
|
||||
|
||||
debug!("Sharding keys found: {:?}", result);
|
||||
|
||||
@@ -438,7 +430,7 @@ impl QueryRouter {
|
||||
SetExpr::Select(select) => {
|
||||
match &select.selection {
|
||||
Some(selection) => {
|
||||
let sharding_keys = self.selection_parser(&selection);
|
||||
let sharding_keys = self.selection_parser(selection);
|
||||
|
||||
// TODO: Add support for prepared statements here.
|
||||
// This should just give us the position of the value in the `B` message.
|
||||
@@ -484,10 +476,7 @@ impl QueryRouter {
|
||||
|
||||
/// Get desired shard we should be talking to.
|
||||
pub fn shard(&self) -> usize {
|
||||
match self.active_shard {
|
||||
Some(shard) => shard,
|
||||
None => 0,
|
||||
}
|
||||
self.active_shard.unwrap_or(0)
|
||||
}
|
||||
|
||||
pub fn set_shard(&mut self, shard: usize) {
|
||||
@@ -531,7 +520,7 @@ mod test {
|
||||
QueryRouter::setup();
|
||||
let mut qr = QueryRouter::new();
|
||||
assert!(qr.try_execute_command(simple_query("SET SERVER ROLE TO 'auto'")) != None);
|
||||
assert_eq!(qr.query_parser_enabled(), true);
|
||||
assert!(qr.query_parser_enabled());
|
||||
|
||||
assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None);
|
||||
|
||||
@@ -656,9 +645,9 @@ mod test {
|
||||
for (i, test) in tests.iter().enumerate() {
|
||||
if !list[matches[i]].is_match(test) {
|
||||
println!("{} does not match {}", test, list[matches[i]]);
|
||||
assert!(false);
|
||||
panic!();
|
||||
}
|
||||
assert_eq!(set.matches(test).into_iter().collect::<Vec<_>>().len(), 1);
|
||||
assert_eq!(set.matches(test).into_iter().count(), 1);
|
||||
}
|
||||
|
||||
let bad = [
|
||||
@@ -667,7 +656,7 @@ mod test {
|
||||
];
|
||||
|
||||
for query in &bad {
|
||||
assert_eq!(set.matches(query).into_iter().collect::<Vec<_>>().len(), 0);
|
||||
assert_eq!(set.matches(query).into_iter().count(), 0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -760,11 +749,11 @@ mod test {
|
||||
assert_eq!(qr.role(), None);
|
||||
|
||||
let query = simple_query("INSERT INTO test_table VALUES (1)");
|
||||
assert_eq!(qr.infer(query), true);
|
||||
assert!(qr.infer(query));
|
||||
assert_eq!(qr.role(), Some(Role::Primary));
|
||||
|
||||
let query = simple_query("SELECT * FROM test_table");
|
||||
assert_eq!(qr.infer(query), true);
|
||||
assert!(qr.infer(query));
|
||||
assert_eq!(qr.role(), Some(Role::Replica));
|
||||
|
||||
assert!(qr.query_parser_enabled());
|
||||
@@ -798,8 +787,8 @@ mod test {
|
||||
|
||||
assert_eq!(qr.active_role, None);
|
||||
assert_eq!(qr.active_shard, None);
|
||||
assert_eq!(qr.query_parser_enabled(), true);
|
||||
assert_eq!(qr.primary_reads_enabled(), false);
|
||||
assert!(qr.query_parser_enabled());
|
||||
assert!(!qr.primary_reads_enabled());
|
||||
|
||||
let q1 = simple_query("SET SERVER ROLE TO 'primary'");
|
||||
assert!(qr.try_execute_command(q1) != None);
|
||||
@@ -807,7 +796,7 @@ mod test {
|
||||
|
||||
let q2 = simple_query("SET SERVER ROLE TO 'default'");
|
||||
assert!(qr.try_execute_command(q2) != None);
|
||||
assert_eq!(qr.active_role.unwrap(), pool_settings.clone().default_role);
|
||||
assert_eq!(qr.active_role.unwrap(), pool_settings.default_role);
|
||||
|
||||
// Here we go :)
|
||||
let q3 = simple_query("SELECT * FROM test WHERE id = 5 AND values IN (1, 2, 3)");
|
||||
|
||||
16
src/scram.rs
16
src/scram.rs
@@ -57,7 +57,7 @@ impl ScramSha256 {
|
||||
|
||||
/// Used for testing.
|
||||
pub fn from_nonce(password: &str, nonce: &str) -> ScramSha256 {
|
||||
let message = BytesMut::from(&format!("{}n=,r={}", "n,,", nonce).as_bytes()[..]);
|
||||
let message = BytesMut::from(format!("{}n=,r={}", "n,,", nonce).as_bytes());
|
||||
|
||||
ScramSha256 {
|
||||
password: password.to_string(),
|
||||
@@ -87,7 +87,7 @@ impl ScramSha256 {
|
||||
};
|
||||
|
||||
let salted_password = Self::hi(
|
||||
&normalize(&self.password.as_bytes()[..]),
|
||||
&normalize(self.password.as_bytes()),
|
||||
&salt,
|
||||
server_message.iterations,
|
||||
);
|
||||
@@ -181,7 +181,7 @@ impl ScramSha256 {
|
||||
|
||||
match hmac.verify_slice(&verifier) {
|
||||
Ok(_) => Ok(()),
|
||||
Err(_) => return Err(Error::ServerError),
|
||||
Err(_) => Err(Error::ServerError),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -220,7 +220,7 @@ impl Message {
|
||||
/// Parse the server SASL challenge.
|
||||
fn parse(message: &BytesMut) -> Result<Message, Error> {
|
||||
let parts = String::from_utf8_lossy(&message[..])
|
||||
.split(",")
|
||||
.split(',')
|
||||
.map(|s| s.to_string())
|
||||
.collect::<Vec<String>>();
|
||||
|
||||
@@ -268,7 +268,7 @@ mod test {
|
||||
#[test]
|
||||
fn parse_server_first_message() {
|
||||
let message = BytesMut::from(
|
||||
&"r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,s=QSXCR+Q6sek8bf92,i=4096".as_bytes()[..],
|
||||
"r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,s=QSXCR+Q6sek8bf92,i=4096".as_bytes(),
|
||||
);
|
||||
let message = Message::parse(&message).unwrap();
|
||||
assert_eq!(message.nonce, "fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j");
|
||||
@@ -279,7 +279,7 @@ mod test {
|
||||
#[test]
|
||||
fn parse_server_last_message() {
|
||||
let f = FinalMessage::parse(&BytesMut::from(
|
||||
&"v=U+ppxD5XUKtradnv8e2MkeupiA8FU87Sg8CXzXHDAzw".as_bytes()[..],
|
||||
"v=U+ppxD5XUKtradnv8e2MkeupiA8FU87Sg8CXzXHDAzw".as_bytes(),
|
||||
))
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
@@ -309,12 +309,12 @@ mod test {
|
||||
assert_eq!(std::str::from_utf8(&message).unwrap(), client_first);
|
||||
|
||||
let result = scram
|
||||
.update(&BytesMut::from(&server_first.as_bytes()[..]))
|
||||
.update(&BytesMut::from(server_first.as_bytes()))
|
||||
.unwrap();
|
||||
assert_eq!(std::str::from_utf8(&result).unwrap(), client_final);
|
||||
|
||||
scram
|
||||
.finish(&BytesMut::from(&server_final.as_bytes()[..]))
|
||||
.finish(&BytesMut::from(server_final.as_bytes()))
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -175,7 +175,7 @@ impl Server {
|
||||
+ sasl_response.len() as i32, // length of SASL response
|
||||
);
|
||||
|
||||
res.put_slice(&format!("{}\0", SCRAM_SHA_256).as_bytes()[..]);
|
||||
res.put_slice(format!("{}\0", SCRAM_SHA_256).as_bytes());
|
||||
res.put_i32(sasl_response.len() as i32);
|
||||
res.put(sasl_response);
|
||||
|
||||
@@ -315,19 +315,19 @@ impl Server {
|
||||
let mut server = Server {
|
||||
address: address.clone(),
|
||||
read: BufReader::new(read),
|
||||
write: write,
|
||||
write,
|
||||
buffer: BytesMut::with_capacity(8196),
|
||||
server_info: server_info,
|
||||
server_id: server_id,
|
||||
process_id: process_id,
|
||||
secret_key: secret_key,
|
||||
server_info,
|
||||
server_id,
|
||||
process_id,
|
||||
secret_key,
|
||||
in_transaction: false,
|
||||
data_available: false,
|
||||
bad: false,
|
||||
needs_cleanup: false,
|
||||
client_server_map: client_server_map,
|
||||
client_server_map,
|
||||
connected_at: chrono::offset::Utc::now().naive_utc(),
|
||||
stats: stats,
|
||||
stats,
|
||||
application_name: String::new(),
|
||||
last_activity: SystemTime::now(),
|
||||
};
|
||||
@@ -371,7 +371,7 @@ impl Server {
|
||||
bytes.put_i32(process_id);
|
||||
bytes.put_i32(secret_key);
|
||||
|
||||
Ok(write_all(&mut stream, bytes).await?)
|
||||
write_all(&mut stream, bytes).await
|
||||
}
|
||||
|
||||
/// Send messages to the server from the client.
|
||||
@@ -616,7 +616,7 @@ impl Server {
|
||||
self.needs_cleanup = false;
|
||||
}
|
||||
|
||||
return Ok(());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// A shorthand for `SET application_name = $1`.
|
||||
@@ -631,7 +631,7 @@ impl Server {
|
||||
.query(&format!("SET application_name = '{}'", name))
|
||||
.await?);
|
||||
self.needs_cleanup = needs_cleanup_before;
|
||||
return result;
|
||||
result
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -133,7 +133,7 @@ impl Sharder {
|
||||
#[inline]
|
||||
fn combine(mut a: u64, b: u64) -> u64 {
|
||||
a ^= b
|
||||
.wrapping_add(0x49a0f4dd15e5a8e3 as u64)
|
||||
.wrapping_add(0x49a0f4dd15e5a8e3_u64)
|
||||
.wrapping_add(a << 54)
|
||||
.wrapping_add(a >> 7);
|
||||
a
|
||||
@@ -141,7 +141,7 @@ impl Sharder {
|
||||
|
||||
#[inline]
|
||||
fn pg_u32_hash(k: u32) -> u64 {
|
||||
let mut a: u32 = 0x9e3779b9 as u32 + std::mem::size_of::<u32>() as u32 + 3923095 as u32;
|
||||
let mut a: u32 = 0x9e3779b9_u32 + std::mem::size_of::<u32>() as u32 + 3923095_u32;
|
||||
let mut b = a;
|
||||
let c = a;
|
||||
|
||||
|
||||
32
src/stats.rs
32
src/stats.rs
@@ -245,7 +245,7 @@ impl Default for Reporter {
|
||||
impl Reporter {
|
||||
/// Create a new Reporter instance.
|
||||
pub fn new(tx: Sender<Event>) -> Reporter {
|
||||
Reporter { tx: tx }
|
||||
Reporter { tx }
|
||||
}
|
||||
|
||||
/// Send statistics to the task keeping track of stats.
|
||||
@@ -338,9 +338,9 @@ impl Reporter {
|
||||
let event = Event {
|
||||
name: EventName::ClientRegistered {
|
||||
client_id,
|
||||
pool_name: pool_name.clone(),
|
||||
username: username.clone(),
|
||||
application_name: app_name.clone(),
|
||||
pool_name,
|
||||
username,
|
||||
application_name: app_name,
|
||||
},
|
||||
value: 1,
|
||||
};
|
||||
@@ -582,7 +582,7 @@ impl Collector {
|
||||
|
||||
let address_stats = address_stat_lookup
|
||||
.entry(server_info.address_id)
|
||||
.or_insert(HashMap::default());
|
||||
.or_insert_with(HashMap::default);
|
||||
let counter = address_stats
|
||||
.entry("total_query_count".to_string())
|
||||
.or_insert(0);
|
||||
@@ -618,7 +618,7 @@ impl Collector {
|
||||
|
||||
let address_stats = address_stat_lookup
|
||||
.entry(server_info.address_id)
|
||||
.or_insert(HashMap::default());
|
||||
.or_insert_with(HashMap::default);
|
||||
let counter = address_stats
|
||||
.entry("total_xact_count".to_string())
|
||||
.or_insert(0);
|
||||
@@ -636,7 +636,7 @@ impl Collector {
|
||||
|
||||
let address_stats = address_stat_lookup
|
||||
.entry(server_info.address_id)
|
||||
.or_insert(HashMap::default());
|
||||
.or_insert_with(HashMap::default);
|
||||
let counter =
|
||||
address_stats.entry("total_sent".to_string()).or_insert(0);
|
||||
*counter += stat.value;
|
||||
@@ -653,7 +653,7 @@ impl Collector {
|
||||
|
||||
let address_stats = address_stat_lookup
|
||||
.entry(server_info.address_id)
|
||||
.or_insert(HashMap::default());
|
||||
.or_insert_with(HashMap::default);
|
||||
let counter = address_stats
|
||||
.entry("total_received".to_string())
|
||||
.or_insert(0);
|
||||
@@ -683,7 +683,7 @@ impl Collector {
|
||||
|
||||
let address_stats = address_stat_lookup
|
||||
.entry(server_info.address_id)
|
||||
.or_insert(HashMap::default());
|
||||
.or_insert_with(HashMap::default);
|
||||
let counter = address_stats
|
||||
.entry("total_wait_time".to_string())
|
||||
.or_insert(0);
|
||||
@@ -694,7 +694,7 @@ impl Collector {
|
||||
server_info.pool_name.clone(),
|
||||
server_info.username.clone(),
|
||||
))
|
||||
.or_insert(HashMap::default());
|
||||
.or_insert_with(HashMap::default);
|
||||
|
||||
// We record max wait in microseconds, we do the pgbouncer second/microsecond split on admin
|
||||
let old_microseconds =
|
||||
@@ -750,7 +750,7 @@ impl Collector {
|
||||
// Update address aggregation stats
|
||||
let address_stats = address_stat_lookup
|
||||
.entry(address_id)
|
||||
.or_insert(HashMap::default());
|
||||
.or_insert_with(HashMap::default);
|
||||
let counter = address_stats.entry("total_errors".to_string()).or_insert(0);
|
||||
*counter += stat.value;
|
||||
}
|
||||
@@ -770,7 +770,7 @@ impl Collector {
|
||||
// Update address aggregation stats
|
||||
let address_stats = address_stat_lookup
|
||||
.entry(address_id)
|
||||
.or_insert(HashMap::default());
|
||||
.or_insert_with(HashMap::default);
|
||||
let counter = address_stats.entry("total_errors".to_string()).or_insert(0);
|
||||
*counter += stat.value;
|
||||
}
|
||||
@@ -891,7 +891,7 @@ impl Collector {
|
||||
} => {
|
||||
let pool_stats = pool_stat_lookup
|
||||
.entry((pool_name.clone(), username.clone()))
|
||||
.or_insert(HashMap::default());
|
||||
.or_insert_with(HashMap::default);
|
||||
|
||||
// These are re-calculated every iteration of the loop, so we don't want to add values
|
||||
// from the last iteration.
|
||||
@@ -964,17 +964,17 @@ impl Collector {
|
||||
// Clear maxwait after reporting
|
||||
pool_stat_lookup
|
||||
.entry((pool_name.clone(), username.clone()))
|
||||
.or_insert(HashMap::default())
|
||||
.or_insert_with(HashMap::default)
|
||||
.insert("maxwait_us".to_string(), 0);
|
||||
}
|
||||
|
||||
EventName::UpdateAverages { address_id } => {
|
||||
let stats = address_stat_lookup
|
||||
.entry(address_id)
|
||||
.or_insert(HashMap::default());
|
||||
.or_insert_with(HashMap::default);
|
||||
let old_stats = address_old_stat_lookup
|
||||
.entry(address_id)
|
||||
.or_insert(HashMap::default());
|
||||
.or_insert_with(HashMap::default);
|
||||
|
||||
// Calculate averages
|
||||
for stat in &[
|
||||
|
||||
@@ -30,12 +30,12 @@ impl Tls {
|
||||
pub fn new() -> Result<Self, Error> {
|
||||
let config = get_config();
|
||||
|
||||
let certs = match load_certs(&Path::new(&config.general.tls_certificate.unwrap())) {
|
||||
let certs = match load_certs(Path::new(&config.general.tls_certificate.unwrap())) {
|
||||
Ok(certs) => certs,
|
||||
Err(_) => return Err(Error::TlsError),
|
||||
};
|
||||
|
||||
let mut keys = match load_keys(&Path::new(&config.general.tls_private_key.unwrap())) {
|
||||
let mut keys = match load_keys(Path::new(&config.general.tls_private_key.unwrap())) {
|
||||
Ok(keys) => keys,
|
||||
Err(_) => return Err(Error::TlsError),
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user