This commit is contained in:
Lev Kokotov
2023-03-30 14:16:50 -07:00
parent 345ee88342
commit fef737ea43
3 changed files with 40 additions and 16 deletions

View File

@@ -5,16 +5,24 @@
use crate::errors::Error; use crate::errors::Error;
use crate::tokio::io::AsyncReadExt; use crate::tokio::io::AsyncReadExt;
use crate::{ use crate::{
config::get_config,
messages::{error_response, md5_hash_password, write_all, wrong_password, md5_hash_second_pass},
pool::{get_pool, ConnectionPool},
auth_passthrough::AuthPassthrough, auth_passthrough::AuthPassthrough,
config::get_config,
messages::{
error_response, md5_hash_password, md5_hash_second_pass, write_all, wrong_password,
},
pool::{get_pool, ConnectionPool},
}; };
use bytes::{BufMut, BytesMut}; use bytes::{BufMut, BytesMut};
use log::debug; use log::debug;
async fn refetch_auth_hash<S>(pool: &ConnectionPool, stream: &mut S, username: &str, pool_name: &str) -> Result<String, Error> async fn refetch_auth_hash<S>(
where S: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send pool: &ConnectionPool,
stream: &mut S,
username: &str,
pool_name: &str,
) -> Result<String, Error>
where
S: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send,
{ {
let address = pool.address(0, 0); let address = pool.address(0, 0);
if let Some(apt) = AuthPassthrough::from_pool_settings(&pool.settings) { if let Some(apt) = AuthPassthrough::from_pool_settings(&pool.settings) {
@@ -29,7 +37,8 @@ where S: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send
"No password set and auth passthrough failed for database: {}, user: {}", "No password set and auth passthrough failed for database: {}, user: {}",
pool_name, username pool_name, username
), ),
).await?; )
.await?;
Err(Error::ClientError(format!( Err(Error::ClientError(format!(
"Could not obtain hash for {{ username: {:?}, database: {:?} }}. Auth passthrough not enabled.", "Could not obtain hash for {{ username: {:?}, database: {:?} }}. Auth passthrough not enabled.",
@@ -174,8 +183,7 @@ impl ClearText {
"Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", "Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}",
self.username, self.pool_name, self.application_name self.username, self.pool_name, self.application_name
))) )))
} } else {
else {
validate_pool(write, pool, &self.username, &self.pool_name).await?; validate_pool(write, pool, &self.username, &self.pool_name).await?;
Ok(None) Ok(None)
@@ -305,7 +313,7 @@ impl Md5 {
Some(ref password) => { Some(ref password) => {
let our_hash = md5_hash_password(&self.username, password, &self.salt); let our_hash = md5_hash_password(&self.username, password, &self.salt);
if our_hash != password_hash { if our_hash != password_hash {
wrong_password(write, &self.username).await?; wrong_password(write, &self.username).await?;
Err(Error::ClientError(format!( Err(Error::ClientError(format!(
@@ -324,7 +332,10 @@ impl Md5 {
let hash = match hash { let hash = match hash {
Some(hash) => hash.to_string(), Some(hash) => hash.to_string(),
None => refetch_auth_hash(&pool, write, &self.username, &self.pool_name).await?, None => {
refetch_auth_hash(&pool, write, &self.username, &self.pool_name)
.await?
}
}; };
let our_hash = md5_hash_second_pass(&hash, &self.salt); let our_hash = md5_hash_second_pass(&hash, &self.salt);
@@ -332,7 +343,13 @@ impl Md5 {
// Compare hashes // Compare hashes
if our_hash != password_hash { if our_hash != password_hash {
// Server hash maybe changed // Server hash maybe changed
let hash = refetch_auth_hash(&pool, write, &self.username, &self.pool_name).await?; let hash = refetch_auth_hash(
&pool,
write,
&self.username,
&self.pool_name,
)
.await?;
let our_hash = md5_hash_second_pass(&hash, &self.salt); let our_hash = md5_hash_second_pass(&hash, &self.salt);
if our_hash != password_hash { if our_hash != password_hash {
@@ -345,7 +362,13 @@ impl Md5 {
} else { } else {
(*pool.auth_hash.write()) = Some(hash); (*pool.auth_hash.write()) = Some(hash);
validate_pool(write, pool.clone(), &self.username, &self.pool_name).await?; validate_pool(
write,
pool.clone(),
&self.username,
&self.pool_name,
)
.await?;
Ok(()) Ok(())
} }

View File

@@ -61,8 +61,8 @@ use std::sync::Arc;
use tokio::sync::broadcast; use tokio::sync::broadcast;
mod admin; mod admin;
mod auth_passthrough;
mod auth; mod auth;
mod auth_passthrough;
mod client; mod client;
mod config; mod config;
mod constants; mod constants;

View File

@@ -222,7 +222,6 @@ impl ConnectionPool {
secrets.push(None); secrets.push(None);
for secret in secrets { for secret in secrets {
let old_pool_ref = get_pool(pool_name, &user.username, secret.clone()); let old_pool_ref = get_pool(pool_name, &user.username, secret.clone());
let identifier = PoolIdentifier::new(pool_name, &user.username, secret.clone()); let identifier = PoolIdentifier::new(pool_name, &user.username, secret.clone());
@@ -362,7 +361,9 @@ impl ConnectionPool {
let pool = Pool::builder() let pool = Pool::builder()
.max_size(user.pool_size) .max_size(user.pool_size)
.connection_timeout(std::time::Duration::from_millis(connect_timeout)) .connection_timeout(std::time::Duration::from_millis(
connect_timeout,
))
.idle_timeout(Some(std::time::Duration::from_millis(idle_timeout))) .idle_timeout(Some(std::time::Duration::from_millis(idle_timeout)))
.test_on_check_out(false) .test_on_check_out(false)
.build(manager) .build(manager)