mirror of
https://github.com/postgresml/pgcat.git
synced 2026-03-23 01:16:30 +00:00
Zero-downtime password rotation
This commit is contained in:
@@ -58,9 +58,9 @@ tcp_keepalives_count = 5
|
||||
tcp_keepalives_interval = 5
|
||||
|
||||
# Path to TLS Certficate file to use for TLS connections
|
||||
# tls_certificate = "server.cert"
|
||||
tls_certificate = ".circleci/server.cert"
|
||||
# Path to TLS private key file to use for TLS connections
|
||||
# tls_private_key = "server.key"
|
||||
tls_private_key = ".circleci/server.key"
|
||||
|
||||
# User name to access the virtual administrative database (pgbouncer or pgcat)
|
||||
# Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DATABASES`, etc..
|
||||
@@ -129,6 +129,10 @@ connect_timeout = 3000
|
||||
username = "sharding_user"
|
||||
# Postgresql password
|
||||
password = "sharding_user"
|
||||
|
||||
# Passwords the client can use to connect. Useful for password rotations.
|
||||
secrets = [ "secret_one", "secret_two" ]
|
||||
|
||||
# Maximum number of server connections that can be established for this user
|
||||
# The maximum number of connection from a single Pgcat process to any database in the cluster
|
||||
# is the sum of pool_size across all users.
|
||||
|
||||
@@ -780,7 +780,7 @@ where
|
||||
let database = parts[0];
|
||||
let user = parts[1];
|
||||
|
||||
match get_pool(database, user) {
|
||||
match get_pool(database, user, None) {
|
||||
Some(pool) => {
|
||||
pool.pause();
|
||||
|
||||
@@ -827,7 +827,7 @@ where
|
||||
let database = parts[0];
|
||||
let user = parts[1];
|
||||
|
||||
match get_pool(database, user) {
|
||||
match get_pool(database, user, None) {
|
||||
Some(pool) => {
|
||||
pool.resume();
|
||||
|
||||
|
||||
382
src/auth.rs
Normal file
382
src/auth.rs
Normal file
@@ -0,0 +1,382 @@
|
||||
//! Module implementing various client authentication mechanisms.
|
||||
//!
|
||||
//! Currently supported: plain (via TLS), md5 (via TLS and plain text connection).
|
||||
|
||||
use crate::errors::Error;
|
||||
use crate::tokio::io::AsyncReadExt;
|
||||
use crate::{
|
||||
config::get_config,
|
||||
messages::{error_response, md5_hash_password, write_all, wrong_password, md5_hash_second_pass},
|
||||
pool::{get_pool, ConnectionPool},
|
||||
auth_passthrough::AuthPassthrough,
|
||||
};
|
||||
use bytes::{BufMut, BytesMut};
|
||||
use log::debug;
|
||||
|
||||
async fn refetch_auth_hash<S>(pool: &ConnectionPool, stream: &mut S, username: &str, pool_name: &str) -> Result<String, Error>
|
||||
where S: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send
|
||||
{
|
||||
let address = pool.address(0, 0);
|
||||
if let Some(apt) = AuthPassthrough::from_pool_settings(&pool.settings) {
|
||||
let hash = apt.fetch_hash(address).await?;
|
||||
|
||||
return Ok(hash);
|
||||
}
|
||||
|
||||
error_response(
|
||||
stream,
|
||||
&format!(
|
||||
"No password set and auth passthrough failed for database: {}, user: {}",
|
||||
pool_name, username
|
||||
),
|
||||
).await?;
|
||||
|
||||
Err(Error::ClientError(format!(
|
||||
"Could not obtain hash for {{ username: {:?}, database: {:?} }}. Auth passthrough not enabled.",
|
||||
address.username, address.database
|
||||
)))
|
||||
}
|
||||
|
||||
/// Read 'p' message from client.
|
||||
async fn response<R>(stream: &mut R) -> Result<Vec<u8>, Error>
|
||||
where
|
||||
R: tokio::io::AsyncRead + std::marker::Unpin + std::marker::Send,
|
||||
{
|
||||
let code = match stream.read_u8().await {
|
||||
Ok(code) => code,
|
||||
Err(_) => {
|
||||
return Err(Error::SocketError(
|
||||
"Error reading password code from client".to_string(),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
if code as char != 'p' {
|
||||
return Err(Error::SocketError(format!("Expected p, got {}", code)));
|
||||
}
|
||||
|
||||
let len = match stream.read_i32().await {
|
||||
Ok(len) => len,
|
||||
Err(_) => {
|
||||
return Err(Error::SocketError(
|
||||
"Error reading password length from client".to_string(),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let mut response = vec![0; (len - 4) as usize];
|
||||
|
||||
match stream.read_exact(&mut response).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => {
|
||||
return Err(Error::SocketError(
|
||||
"Error reading password from client".to_string(),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
Ok(response.to_vec())
|
||||
}
|
||||
|
||||
/// Make sure the pool we authenticated to has at least one server connection
|
||||
/// that can serve our request.
|
||||
async fn validate_pool<W>(
|
||||
stream: &mut W,
|
||||
mut pool: ConnectionPool,
|
||||
username: &str,
|
||||
pool_name: &str,
|
||||
) -> Result<(), Error>
|
||||
where
|
||||
W: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send,
|
||||
{
|
||||
if !pool.validated() {
|
||||
match pool.validate().await {
|
||||
Ok(_) => Ok(()),
|
||||
Err(err) => {
|
||||
error_response(
|
||||
stream,
|
||||
&format!(
|
||||
"Pool down for database: {:?}, user: {:?}",
|
||||
pool_name, username,
|
||||
),
|
||||
)
|
||||
.await?;
|
||||
|
||||
Err(Error::ClientError(format!("Pool down: {:?}", err)))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Clear text authentication.
|
||||
///
|
||||
/// The client will send the password in plain text over the wire.
|
||||
/// To protect against obvious security issues, this is only used over TLS.
|
||||
///
|
||||
/// Clear text authentication is used to support zero-downtime password rotation.
|
||||
/// It allows the client to use multiple passwords when talking to the PgCat
|
||||
/// while the password is being rotated across multiple app instances.
|
||||
pub struct ClearText {
|
||||
username: String,
|
||||
pool_name: String,
|
||||
application_name: String,
|
||||
}
|
||||
|
||||
impl ClearText {
|
||||
/// Create a new ClearText authentication mechanism.
|
||||
pub fn new(username: &str, pool_name: &str, application_name: &str) -> ClearText {
|
||||
ClearText {
|
||||
username: username.to_string(),
|
||||
pool_name: pool_name.to_string(),
|
||||
application_name: application_name.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Issue 'R' clear text challenge to client.
|
||||
pub async fn challenge<W>(&self, stream: &mut W) -> Result<(), Error>
|
||||
where
|
||||
W: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send,
|
||||
{
|
||||
debug!("Sending plain challenge");
|
||||
|
||||
let mut msg = BytesMut::new();
|
||||
msg.put_u8(b'R');
|
||||
msg.put_i32(8);
|
||||
msg.put_i32(3); // Clear text
|
||||
|
||||
write_all(stream, msg).await
|
||||
}
|
||||
|
||||
/// Authenticate client with server password or secret.
|
||||
pub async fn authenticate<R, W>(
|
||||
&self,
|
||||
read: &mut R,
|
||||
write: &mut W,
|
||||
) -> Result<Option<String>, Error>
|
||||
where
|
||||
R: tokio::io::AsyncRead + std::marker::Unpin + std::marker::Send,
|
||||
W: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send,
|
||||
{
|
||||
let response = response(read).await?;
|
||||
|
||||
let secret = String::from_utf8_lossy(&response[0..response.len() - 1]).to_string();
|
||||
|
||||
match get_pool(&self.pool_name, &self.username, Some(secret.clone())) {
|
||||
None => match get_pool(&self.pool_name, &self.username, None) {
|
||||
Some(pool) => {
|
||||
match pool.settings.user.password {
|
||||
Some(ref password) => {
|
||||
if password != &secret {
|
||||
wrong_password(write, &self.username).await?;
|
||||
Err(Error::ClientError(format!(
|
||||
"Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}",
|
||||
self.username, self.pool_name, self.application_name
|
||||
)))
|
||||
}
|
||||
else {
|
||||
validate_pool(write, pool, &self.username, &self.pool_name).await?;
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
None => {
|
||||
// Server is storing hashes, we can't query it for the plain text password.
|
||||
error_response(
|
||||
write,
|
||||
&format!(
|
||||
"No server password configured for database: {:?}, user: {:?}",
|
||||
self.pool_name, self.username
|
||||
),
|
||||
)
|
||||
.await?;
|
||||
|
||||
Err(Error::ClientError(format!(
|
||||
"No server password configured for {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}",
|
||||
self.username, self.pool_name, self.application_name
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None => {
|
||||
error_response(
|
||||
write,
|
||||
&format!(
|
||||
"No pool configured for database: {:?}, user: {:?}",
|
||||
self.pool_name, self.username
|
||||
),
|
||||
)
|
||||
.await?;
|
||||
|
||||
Err(Error::ClientError(format!(
|
||||
"Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}",
|
||||
self.username, self.pool_name, self.application_name
|
||||
)))
|
||||
}
|
||||
},
|
||||
Some(pool) => {
|
||||
validate_pool(write, pool, &self.username, &self.pool_name).await?;
|
||||
Ok(Some(secret))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// MD5 hash authentication.
|
||||
///
|
||||
/// Deprecated, but widely used everywhere, and currently required for poolers
|
||||
/// to authencticate clients without involving Postgres.
|
||||
///
|
||||
/// Admin clients are required to use MD5.
|
||||
pub struct Md5 {
|
||||
username: String,
|
||||
pool_name: String,
|
||||
application_name: String,
|
||||
salt: [u8; 4],
|
||||
admin: bool,
|
||||
}
|
||||
|
||||
impl Md5 {
|
||||
pub fn new(username: &str, pool_name: &str, application_name: &str, admin: bool) -> Md5 {
|
||||
let salt: [u8; 4] = [
|
||||
rand::random(),
|
||||
rand::random(),
|
||||
rand::random(),
|
||||
rand::random(),
|
||||
];
|
||||
|
||||
Md5 {
|
||||
username: username.to_string(),
|
||||
pool_name: pool_name.to_string(),
|
||||
application_name: application_name.to_string(),
|
||||
salt,
|
||||
admin,
|
||||
}
|
||||
}
|
||||
|
||||
/// Issue a 'R' MD5 challenge to the client.
|
||||
pub async fn challenge<W>(&self, stream: &mut W) -> Result<(), Error>
|
||||
where
|
||||
W: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send,
|
||||
{
|
||||
let mut res = BytesMut::new();
|
||||
res.put_u8(b'R');
|
||||
res.put_i32(12);
|
||||
res.put_i32(5); // MD5
|
||||
res.put_slice(&self.salt[..]);
|
||||
|
||||
write_all(stream, res).await
|
||||
}
|
||||
|
||||
/// Authenticate client with MD5. This is used for both admin and normal users.
|
||||
pub async fn authenticate<R, W>(&self, read: &mut R, write: &mut W) -> Result<(), Error>
|
||||
where
|
||||
R: tokio::io::AsyncRead + std::marker::Unpin + std::marker::Send,
|
||||
W: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send,
|
||||
{
|
||||
let password_hash = response(read).await?;
|
||||
|
||||
if self.admin {
|
||||
let config = get_config();
|
||||
|
||||
// Compare server and client hashes.
|
||||
let our_hash = md5_hash_password(
|
||||
&config.general.admin_username,
|
||||
&config.general.admin_password,
|
||||
&self.salt,
|
||||
);
|
||||
|
||||
if our_hash != password_hash {
|
||||
wrong_password(write, &self.username).await?;
|
||||
Err(Error::ClientError(format!(
|
||||
"Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}",
|
||||
self.username, self.pool_name, self.application_name
|
||||
)))
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
} else {
|
||||
match get_pool(&self.pool_name, &self.username, None) {
|
||||
Some(pool) => {
|
||||
match &pool.settings.user.password {
|
||||
Some(ref password) => {
|
||||
let our_hash = md5_hash_password(&self.username, password, &self.salt);
|
||||
|
||||
if our_hash != password_hash {
|
||||
wrong_password(write, &self.username).await?;
|
||||
|
||||
Err(Error::ClientError(format!(
|
||||
"Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}",
|
||||
self.username, self.pool_name, self.application_name
|
||||
)))
|
||||
} else {
|
||||
validate_pool(write, pool, &self.username, &self.pool_name).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
None => {
|
||||
// Fetch hash from server
|
||||
let hash = (*pool.auth_hash.read()).clone();
|
||||
|
||||
let hash = match hash {
|
||||
Some(hash) => hash.to_string(),
|
||||
None => refetch_auth_hash(&pool, write, &self.username, &self.pool_name).await?,
|
||||
};
|
||||
|
||||
let our_hash = md5_hash_second_pass(&hash, &self.salt);
|
||||
|
||||
// Compare hashes
|
||||
if our_hash != password_hash {
|
||||
// Server hash maybe changed
|
||||
let hash = refetch_auth_hash(&pool, write, &self.username, &self.pool_name).await?;
|
||||
let our_hash = md5_hash_second_pass(&hash, &self.salt);
|
||||
|
||||
if our_hash != password_hash {
|
||||
wrong_password(write, &self.username).await?;
|
||||
|
||||
Err(Error::ClientError(format!(
|
||||
"Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}",
|
||||
self.username, self.pool_name, self.application_name
|
||||
)))
|
||||
} else {
|
||||
(*pool.auth_hash.write()) = Some(hash);
|
||||
|
||||
validate_pool(write, pool.clone(), &self.username, &self.pool_name).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
} else {
|
||||
wrong_password(write, &self.username).await?;
|
||||
|
||||
Err(Error::ClientError(format!(
|
||||
"Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}",
|
||||
self.username, self.pool_name, self.application_name
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None => {
|
||||
error_response(
|
||||
write,
|
||||
&format!(
|
||||
"No pool configured for database: {:?}, user: {:?}",
|
||||
self.pool_name, self.username
|
||||
),
|
||||
)
|
||||
.await?;
|
||||
|
||||
return Err(Error::ClientError(format!(
|
||||
"Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}",
|
||||
self.username, self.pool_name, self.application_name
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -73,6 +73,7 @@ impl AuthPassthrough {
|
||||
password: Some(self.password.clone()),
|
||||
pool_size: 1,
|
||||
statement_timeout: 0,
|
||||
secrets: None,
|
||||
};
|
||||
|
||||
let user = &address.username;
|
||||
|
||||
197
src/client.rs
197
src/client.rs
@@ -2,7 +2,7 @@ use crate::errors::Error;
|
||||
use crate::pool::BanReason;
|
||||
/// Handle clients by pretending to be a PostgreSQL server.
|
||||
use bytes::{Buf, BufMut, BytesMut};
|
||||
use log::{debug, error, info, trace, warn};
|
||||
use log::{debug, error, info, trace};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
@@ -90,6 +90,9 @@ pub struct Client<S, T> {
|
||||
/// Application name for this client (defaults to pgcat)
|
||||
application_name: String,
|
||||
|
||||
/// Which secret the user is using to connect, if any.
|
||||
secret: Option<String>,
|
||||
|
||||
/// Used to notify clients about an impending shutdown
|
||||
shutdown: Receiver<()>,
|
||||
}
|
||||
@@ -290,7 +293,7 @@ pub async fn client_entrypoint(
|
||||
/// Handle the first message the client sends.
|
||||
async fn get_startup<S>(stream: &mut S) -> Result<(ClientConnectionType, BytesMut), Error>
|
||||
where
|
||||
S: tokio::io::AsyncRead + std::marker::Unpin + tokio::io::AsyncWrite,
|
||||
S: tokio::io::AsyncRead + std::marker::Unpin + tokio::io::AsyncWrite + std::marker::Send,
|
||||
{
|
||||
// Get startup message length.
|
||||
let len = match stream.read_i32().await {
|
||||
@@ -377,24 +380,10 @@ pub async fn startup_tls(
|
||||
}
|
||||
}
|
||||
|
||||
async fn refetch_auth_hash(pool: &ConnectionPool) -> Result<String, Error> {
|
||||
let address = pool.address(0, 0);
|
||||
if let Some(apt) = AuthPassthrough::from_pool_settings(&pool.settings) {
|
||||
let hash = apt.fetch_hash(address).await?;
|
||||
|
||||
return Ok(hash);
|
||||
}
|
||||
|
||||
Err(Error::ClientError(format!(
|
||||
"Could not obtain hash for {{ username: {:?}, database: {:?} }}. Auth passthrough not enabled.",
|
||||
address.username, address.database
|
||||
)))
|
||||
}
|
||||
|
||||
impl<S, T> Client<S, T>
|
||||
where
|
||||
S: tokio::io::AsyncRead + std::marker::Unpin,
|
||||
T: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||
S: tokio::io::AsyncRead + std::marker::Unpin + std::marker::Send,
|
||||
T: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send,
|
||||
{
|
||||
pub fn is_admin(&self) -> bool {
|
||||
self.admin
|
||||
@@ -457,161 +446,39 @@ where
|
||||
let process_id: i32 = rand::random();
|
||||
let secret_key: i32 = rand::random();
|
||||
|
||||
// Perform MD5 authentication.
|
||||
// TODO: Add SASL support.
|
||||
let salt = md5_challenge(&mut write).await?;
|
||||
let config = get_config();
|
||||
|
||||
let code = match read.read_u8().await {
|
||||
Ok(p) => p,
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading password code from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))),
|
||||
let secret = if admin {
|
||||
debug!("Using md5 auth for admin");
|
||||
let auth = crate::auth::Md5::new(&username, &pool_name, &application_name, true);
|
||||
auth.challenge(&mut write).await?;
|
||||
auth.authenticate(&mut read, &mut write).await?;
|
||||
None
|
||||
} else if !config.tls_enabled() {
|
||||
debug!("Using md5 auth");
|
||||
let auth = crate::auth::Md5::new(&username, &pool_name, &application_name, false);
|
||||
auth.challenge(&mut write).await?;
|
||||
auth.authenticate(&mut read, &mut write).await?;
|
||||
None
|
||||
} else {
|
||||
debug!("Using plain auth");
|
||||
let auth = crate::auth::ClearText::new(&username, &pool_name, &application_name);
|
||||
auth.challenge(&mut write).await?;
|
||||
auth.authenticate(&mut read, &mut write).await?
|
||||
};
|
||||
|
||||
// PasswordMessage
|
||||
if code as char != 'p' {
|
||||
return Err(Error::ProtocolSyncError(format!(
|
||||
"Expected p, got {}",
|
||||
code as char
|
||||
)));
|
||||
}
|
||||
|
||||
let len = match read.read_i32().await {
|
||||
Ok(len) => len,
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading password message length from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))),
|
||||
};
|
||||
|
||||
let mut password_response = vec![0u8; (len - 4) as usize];
|
||||
|
||||
match read.read_exact(&mut password_response).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading password message from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))),
|
||||
};
|
||||
|
||||
// Authenticate admin user.
|
||||
// Authenticated admin user.
|
||||
let (transaction_mode, server_info) = if admin {
|
||||
let config = get_config();
|
||||
// Compare server and client hashes.
|
||||
let password_hash = md5_hash_password(
|
||||
&config.general.admin_username,
|
||||
&config.general.admin_password,
|
||||
&salt,
|
||||
);
|
||||
|
||||
if password_hash != password_response {
|
||||
warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name);
|
||||
wrong_password(&mut write, username).await?;
|
||||
|
||||
return Err(Error::ClientError(format!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name)));
|
||||
}
|
||||
|
||||
(false, generate_server_info_for_admin())
|
||||
}
|
||||
// Authenticate normal user.
|
||||
// Authenticated normal user.
|
||||
else {
|
||||
let mut pool = match get_pool(pool_name, username) {
|
||||
Some(pool) => pool,
|
||||
None => {
|
||||
error_response(
|
||||
&mut write,
|
||||
&format!(
|
||||
"No pool configured for database: {:?}, user: {:?}",
|
||||
pool_name, username
|
||||
),
|
||||
)
|
||||
.await?;
|
||||
|
||||
return Err(Error::ClientError(format!("Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name)));
|
||||
}
|
||||
};
|
||||
|
||||
// Obtain the hash to compare, we give preference to that written in cleartext in config
|
||||
// if there is nothing set in cleartext and auth passthrough (auth_query) is configured, we use the hash obtained
|
||||
// when the pool was created. If there is no hash there, we try to fetch it one more time.
|
||||
let password_hash = if let Some(password) = &pool.settings.user.password {
|
||||
Some(md5_hash_password(username, password, &salt))
|
||||
} else {
|
||||
if !get_config().is_auth_query_configured() {
|
||||
return Err(Error::ClientError(format!("Client auth not possible, no cleartext password set for username: {:?} in config and auth passthrough (query_auth) is not set up.", username)));
|
||||
}
|
||||
|
||||
let mut hash = (*pool.auth_hash.read()).clone();
|
||||
|
||||
if hash.is_none() {
|
||||
warn!("Query auth configured but no hash password found for pool {}. Will try to refetch it.", pool_name);
|
||||
match refetch_auth_hash(&pool).await {
|
||||
Ok(fetched_hash) => {
|
||||
warn!("Password for {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}, obtained. Updating.", username, pool_name, application_name);
|
||||
{
|
||||
let mut pool_auth_hash = pool.auth_hash.write();
|
||||
*pool_auth_hash = Some(fetched_hash.clone());
|
||||
}
|
||||
|
||||
hash = Some(fetched_hash);
|
||||
}
|
||||
Err(err) => {
|
||||
return Err(
|
||||
Error::ClientError(
|
||||
format!("No cleartext password set, and no auth passthrough could not obtain the hash from server for {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}, the error was: {:?}",
|
||||
username,
|
||||
pool_name,
|
||||
application_name,
|
||||
err)
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Some(md5_hash_second_pass(&hash.unwrap(), &salt))
|
||||
};
|
||||
|
||||
// Once we have the resulting hash, we compare with what the client gave us.
|
||||
// If they do not match and auth query is set up, we try to refetch the hash one more time
|
||||
// to see if the password has changed since the pool was created.
|
||||
//
|
||||
// @TODO: we could end up fetching again the same password twice (see above).
|
||||
if password_hash.unwrap() != password_response {
|
||||
warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}, will try to refetch it.", username, pool_name, application_name);
|
||||
let fetched_hash = refetch_auth_hash(&pool).await?;
|
||||
let new_password_hash = md5_hash_second_pass(&fetched_hash, &salt);
|
||||
|
||||
// Ok password changed in server an auth is possible.
|
||||
if new_password_hash == password_response {
|
||||
warn!("Password for {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}, changed in server. Updating.", username, pool_name, application_name);
|
||||
{
|
||||
let mut pool_auth_hash = pool.auth_hash.write();
|
||||
*pool_auth_hash = Some(fetched_hash);
|
||||
}
|
||||
} else {
|
||||
wrong_password(&mut write, username).await?;
|
||||
return Err(Error::ClientError(format!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name)));
|
||||
}
|
||||
}
|
||||
|
||||
let pool = get_pool(&pool_name, &username, secret.clone()).unwrap();
|
||||
let transaction_mode = pool.settings.pool_mode == PoolMode::Transaction;
|
||||
|
||||
// If the pool hasn't been validated yet,
|
||||
// connect to the servers and figure out what's what.
|
||||
if !pool.validated() {
|
||||
match pool.validate().await {
|
||||
Ok(_) => (),
|
||||
Err(err) => {
|
||||
error_response(
|
||||
&mut write,
|
||||
&format!(
|
||||
"Pool down for database: {:?}, user: {:?}",
|
||||
pool_name, username
|
||||
),
|
||||
)
|
||||
.await?;
|
||||
return Err(Error::ClientError(format!("Pool down: {:?}", err)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
(transaction_mode, pool.server_info())
|
||||
};
|
||||
|
||||
debug!("Password authentication successful");
|
||||
debug!("Authentication successful");
|
||||
|
||||
auth_ok(&mut write).await?;
|
||||
write_all(&mut write, server_info).await?;
|
||||
@@ -619,7 +486,7 @@ where
|
||||
ready_for_query(&mut write).await?;
|
||||
|
||||
trace!("Startup OK");
|
||||
let pool_stats = match get_pool(pool_name, username) {
|
||||
let pool_stats = match get_pool(pool_name, username, secret.clone()) {
|
||||
Some(pool) => {
|
||||
if !admin {
|
||||
pool.stats
|
||||
@@ -659,6 +526,7 @@ where
|
||||
application_name: application_name.to_string(),
|
||||
shutdown,
|
||||
connected_to_server: false,
|
||||
secret,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -693,6 +561,7 @@ where
|
||||
application_name: String::from("undefined"),
|
||||
shutdown,
|
||||
connected_to_server: false,
|
||||
secret: None,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1200,7 +1069,7 @@ where
|
||||
/// Retrieve connection pool, if it exists.
|
||||
/// Return an error to the client otherwise.
|
||||
async fn get_pool(&mut self) -> Result<ConnectionPool, Error> {
|
||||
match get_pool(&self.pool_name, &self.username) {
|
||||
match get_pool(&self.pool_name, &self.username, self.secret.clone()) {
|
||||
Some(pool) => Ok(pool),
|
||||
None => {
|
||||
error_response(
|
||||
|
||||
@@ -181,6 +181,13 @@ pub struct User {
|
||||
pub pool_size: u32,
|
||||
#[serde(default)] // 0
|
||||
pub statement_timeout: u64,
|
||||
pub secrets: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
impl User {
|
||||
fn validate(&self) -> Result<(), Error> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for User {
|
||||
@@ -190,6 +197,7 @@ impl Default for User {
|
||||
password: None,
|
||||
pool_size: 15,
|
||||
statement_timeout: 0,
|
||||
secrets: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -508,6 +516,10 @@ impl Pool {
|
||||
None => None,
|
||||
};
|
||||
|
||||
for user in self.users.iter() {
|
||||
user.1.validate()?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -657,6 +669,11 @@ impl Config {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Checks that we configured TLS.
|
||||
pub fn tls_enabled(&self) -> bool {
|
||||
self.general.tls_certificate.is_some() && self.general.tls_private_key.is_some()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
|
||||
@@ -62,6 +62,7 @@ use tokio::sync::broadcast;
|
||||
|
||||
mod admin;
|
||||
mod auth_passthrough;
|
||||
mod auth;
|
||||
mod client;
|
||||
mod config;
|
||||
mod constants;
|
||||
|
||||
@@ -46,29 +46,6 @@ where
|
||||
write_all(stream, auth_ok).await
|
||||
}
|
||||
|
||||
/// Generate md5 password challenge.
|
||||
pub async fn md5_challenge<S>(stream: &mut S) -> Result<[u8; 4], Error>
|
||||
where
|
||||
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||
{
|
||||
// let mut rng = rand::thread_rng();
|
||||
let salt: [u8; 4] = [
|
||||
rand::random(),
|
||||
rand::random(),
|
||||
rand::random(),
|
||||
rand::random(),
|
||||
];
|
||||
|
||||
let mut res = BytesMut::new();
|
||||
res.put_u8(b'R');
|
||||
res.put_i32(12);
|
||||
res.put_i32(5); // MD5
|
||||
res.put_slice(&salt[..]);
|
||||
|
||||
write_all(stream, res).await?;
|
||||
Ok(salt)
|
||||
}
|
||||
|
||||
/// Give the client the process_id and secret we generated
|
||||
/// used in query cancellation.
|
||||
pub async fn backend_key_data<S>(
|
||||
|
||||
@@ -34,7 +34,7 @@ impl MirroredClient {
|
||||
None => (default, default, crate::config::Pool::default()),
|
||||
};
|
||||
|
||||
let identifier = PoolIdentifier::new(&self.database, &self.user.username);
|
||||
let identifier = PoolIdentifier::new(&self.database, &self.user.username, None);
|
||||
|
||||
let manager = ServerPool::new(
|
||||
self.address.clone(),
|
||||
|
||||
444
src/pool.rs
444
src/pool.rs
@@ -59,24 +59,22 @@ pub struct PoolIdentifier {
|
||||
|
||||
/// The username the client connects with. Each user gets its own pool.
|
||||
pub user: String,
|
||||
|
||||
/// The client secret (password).
|
||||
pub secret: Option<String>,
|
||||
}
|
||||
|
||||
impl PoolIdentifier {
|
||||
/// Create a new user/pool identifier.
|
||||
pub fn new(db: &str, user: &str) -> PoolIdentifier {
|
||||
pub fn new(db: &str, user: &str, secret: Option<String>) -> PoolIdentifier {
|
||||
PoolIdentifier {
|
||||
db: db.to_string(),
|
||||
user: user.to_string(),
|
||||
secret,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&Address> for PoolIdentifier {
|
||||
fn from(address: &Address) -> PoolIdentifier {
|
||||
PoolIdentifier::new(&address.database, &address.username)
|
||||
}
|
||||
}
|
||||
|
||||
/// Pool settings.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct PoolSettings {
|
||||
@@ -210,224 +208,240 @@ impl ConnectionPool {
|
||||
|
||||
// There is one pool per database/user pair.
|
||||
for user in pool_config.users.values() {
|
||||
let old_pool_ref = get_pool(pool_name, &user.username);
|
||||
let identifier = PoolIdentifier::new(pool_name, &user.username);
|
||||
let mut secrets = match &user.secrets {
|
||||
Some(_) => user
|
||||
.secrets
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.map(|secret| Some(secret.to_string()))
|
||||
.collect::<Vec<Option<String>>>(),
|
||||
None => vec![],
|
||||
};
|
||||
|
||||
secrets.push(None);
|
||||
|
||||
match old_pool_ref {
|
||||
Some(pool) => {
|
||||
// If the pool hasn't changed, get existing reference and insert it into the new_pools.
|
||||
// We replace all pools at the end, but if the reference is kept, the pool won't get re-created (bb8).
|
||||
if pool.config_hash == new_pool_hash_value {
|
||||
info!(
|
||||
"[pool: {}][user: {}] has not changed",
|
||||
pool_name, user.username
|
||||
);
|
||||
new_pools.insert(identifier.clone(), pool.clone());
|
||||
continue;
|
||||
}
|
||||
}
|
||||
None => (),
|
||||
}
|
||||
for secret in secrets {
|
||||
|
||||
info!(
|
||||
"[pool: {}][user: {}] creating new pool",
|
||||
pool_name, user.username
|
||||
);
|
||||
let old_pool_ref = get_pool(pool_name, &user.username, secret.clone());
|
||||
let identifier = PoolIdentifier::new(pool_name, &user.username, secret.clone());
|
||||
|
||||
let mut shards = Vec::new();
|
||||
let mut addresses = Vec::new();
|
||||
let mut banlist = Vec::new();
|
||||
let mut shard_ids = pool_config
|
||||
.shards
|
||||
.clone()
|
||||
.into_keys()
|
||||
.collect::<Vec<String>>();
|
||||
let pool_stats = Arc::new(PoolStats::new(identifier, pool_config.clone()));
|
||||
|
||||
// Allow the pool to be seen in statistics
|
||||
pool_stats.register(pool_stats.clone());
|
||||
|
||||
// Sort by shard number to ensure consistency.
|
||||
shard_ids.sort_by_key(|k| k.parse::<i64>().unwrap());
|
||||
let pool_auth_hash: Arc<RwLock<Option<String>>> = Arc::new(RwLock::new(None));
|
||||
|
||||
for shard_idx in &shard_ids {
|
||||
let shard = &pool_config.shards[shard_idx];
|
||||
let mut pools = Vec::new();
|
||||
let mut servers = Vec::new();
|
||||
let mut replica_number = 0;
|
||||
|
||||
// Load Mirror settings
|
||||
for (address_index, server) in shard.servers.iter().enumerate() {
|
||||
let mut mirror_addresses = vec![];
|
||||
if let Some(mirror_settings_vec) = &shard.mirrors {
|
||||
for (mirror_idx, mirror_settings) in
|
||||
mirror_settings_vec.iter().enumerate()
|
||||
{
|
||||
if mirror_settings.mirroring_target_index != address_index {
|
||||
continue;
|
||||
}
|
||||
mirror_addresses.push(Address {
|
||||
id: address_id,
|
||||
database: shard.database.clone(),
|
||||
host: mirror_settings.host.clone(),
|
||||
port: mirror_settings.port,
|
||||
role: server.role,
|
||||
address_index: mirror_idx,
|
||||
replica_number,
|
||||
shard: shard_idx.parse::<usize>().unwrap(),
|
||||
username: user.username.clone(),
|
||||
pool_name: pool_name.clone(),
|
||||
mirrors: vec![],
|
||||
stats: Arc::new(AddressStats::default()),
|
||||
});
|
||||
address_id += 1;
|
||||
match old_pool_ref {
|
||||
Some(pool) => {
|
||||
// If the pool hasn't changed, get existing reference and insert it into the new_pools.
|
||||
// We replace all pools at the end, but if the reference is kept, the pool won't get re-created (bb8).
|
||||
if pool.config_hash == new_pool_hash_value {
|
||||
info!(
|
||||
"[pool: {}][user: {}] has not changed",
|
||||
pool_name, user.username
|
||||
);
|
||||
new_pools.insert(identifier.clone(), pool.clone());
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
let address = Address {
|
||||
id: address_id,
|
||||
database: shard.database.clone(),
|
||||
host: server.host.clone(),
|
||||
port: server.port,
|
||||
role: server.role,
|
||||
address_index,
|
||||
replica_number,
|
||||
shard: shard_idx.parse::<usize>().unwrap(),
|
||||
username: user.username.clone(),
|
||||
pool_name: pool_name.clone(),
|
||||
mirrors: mirror_addresses,
|
||||
stats: Arc::new(AddressStats::default()),
|
||||
};
|
||||
|
||||
address_id += 1;
|
||||
|
||||
if server.role == Role::Replica {
|
||||
replica_number += 1;
|
||||
}
|
||||
|
||||
// We assume every server in the pool share user/passwords
|
||||
let auth_passthrough = AuthPassthrough::from_pool_config(pool_config);
|
||||
|
||||
if let Some(apt) = &auth_passthrough {
|
||||
match apt.fetch_hash(&address).await {
|
||||
Ok(ok) => {
|
||||
if let Some(ref pool_auth_hash_value) = *(pool_auth_hash.read()) {
|
||||
if ok != *pool_auth_hash_value {
|
||||
warn!("Hash is not the same across shards of the same pool, client auth will \
|
||||
be done using last obtained hash. Server: {}:{}, Database: {}", server.host, server.port, shard.database);
|
||||
}
|
||||
}
|
||||
debug!("Hash obtained for {:?}", address);
|
||||
{
|
||||
let mut pool_auth_hash = pool_auth_hash.write();
|
||||
*pool_auth_hash = Some(ok.clone());
|
||||
}
|
||||
},
|
||||
Err(err) => warn!("Could not obtain password hashes using auth_query config, ignoring. Error: {:?}", err),
|
||||
}
|
||||
}
|
||||
|
||||
let manager = ServerPool::new(
|
||||
address.clone(),
|
||||
user.clone(),
|
||||
&shard.database,
|
||||
client_server_map.clone(),
|
||||
pool_stats.clone(),
|
||||
pool_auth_hash.clone(),
|
||||
);
|
||||
|
||||
let connect_timeout = match pool_config.connect_timeout {
|
||||
Some(connect_timeout) => connect_timeout,
|
||||
None => config.general.connect_timeout,
|
||||
};
|
||||
|
||||
let idle_timeout = match pool_config.idle_timeout {
|
||||
Some(idle_timeout) => idle_timeout,
|
||||
None => config.general.idle_timeout,
|
||||
};
|
||||
|
||||
let pool = Pool::builder()
|
||||
.max_size(user.pool_size)
|
||||
.connection_timeout(std::time::Duration::from_millis(connect_timeout))
|
||||
.idle_timeout(Some(std::time::Duration::from_millis(idle_timeout)))
|
||||
.test_on_check_out(false)
|
||||
.build(manager)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
pools.push(pool);
|
||||
servers.push(address);
|
||||
None => (),
|
||||
}
|
||||
|
||||
shards.push(pools);
|
||||
addresses.push(servers);
|
||||
banlist.push(HashMap::new());
|
||||
}
|
||||
|
||||
assert_eq!(shards.len(), addresses.len());
|
||||
if let Some(ref _auth_hash) = *(pool_auth_hash.clone().read()) {
|
||||
info!(
|
||||
"Auth hash obtained from query_auth for pool {{ name: {}, user: {} }}",
|
||||
"[pool: {}][user: {}] creating new pool",
|
||||
pool_name, user.username
|
||||
);
|
||||
}
|
||||
|
||||
let pool = ConnectionPool {
|
||||
databases: shards,
|
||||
stats: pool_stats,
|
||||
addresses,
|
||||
banlist: Arc::new(RwLock::new(banlist)),
|
||||
config_hash: new_pool_hash_value,
|
||||
server_info: Arc::new(RwLock::new(BytesMut::new())),
|
||||
auth_hash: pool_auth_hash,
|
||||
settings: PoolSettings {
|
||||
pool_mode: pool_config.pool_mode,
|
||||
load_balancing_mode: pool_config.load_balancing_mode,
|
||||
// shards: pool_config.shards.clone(),
|
||||
shards: shard_ids.len(),
|
||||
user: user.clone(),
|
||||
default_role: match pool_config.default_role.as_str() {
|
||||
"any" => None,
|
||||
"replica" => Some(Role::Replica),
|
||||
"primary" => Some(Role::Primary),
|
||||
_ => unreachable!(),
|
||||
let mut shards = Vec::new();
|
||||
let mut addresses = Vec::new();
|
||||
let mut banlist = Vec::new();
|
||||
let mut shard_ids = pool_config
|
||||
.shards
|
||||
.clone()
|
||||
.into_keys()
|
||||
.collect::<Vec<String>>();
|
||||
let pool_stats = Arc::new(PoolStats::new(identifier, pool_config.clone()));
|
||||
|
||||
// Allow the pool to be seen in statistics
|
||||
pool_stats.register(pool_stats.clone());
|
||||
|
||||
// Sort by shard number to ensure consistency.
|
||||
shard_ids.sort_by_key(|k| k.parse::<i64>().unwrap());
|
||||
let pool_auth_hash: Arc<RwLock<Option<String>>> = Arc::new(RwLock::new(None));
|
||||
|
||||
for shard_idx in &shard_ids {
|
||||
let shard = &pool_config.shards[shard_idx];
|
||||
let mut pools = Vec::new();
|
||||
let mut servers = Vec::new();
|
||||
let mut replica_number = 0;
|
||||
|
||||
// Load Mirror settings
|
||||
for (address_index, server) in shard.servers.iter().enumerate() {
|
||||
let mut mirror_addresses = vec![];
|
||||
if let Some(mirror_settings_vec) = &shard.mirrors {
|
||||
for (mirror_idx, mirror_settings) in
|
||||
mirror_settings_vec.iter().enumerate()
|
||||
{
|
||||
if mirror_settings.mirroring_target_index != address_index {
|
||||
continue;
|
||||
}
|
||||
mirror_addresses.push(Address {
|
||||
id: address_id,
|
||||
database: shard.database.clone(),
|
||||
host: mirror_settings.host.clone(),
|
||||
port: mirror_settings.port,
|
||||
role: server.role,
|
||||
address_index: mirror_idx,
|
||||
replica_number,
|
||||
shard: shard_idx.parse::<usize>().unwrap(),
|
||||
username: user.username.clone(),
|
||||
pool_name: pool_name.clone(),
|
||||
mirrors: vec![],
|
||||
stats: Arc::new(AddressStats::default()),
|
||||
});
|
||||
address_id += 1;
|
||||
}
|
||||
}
|
||||
|
||||
let address = Address {
|
||||
id: address_id,
|
||||
database: shard.database.clone(),
|
||||
host: server.host.clone(),
|
||||
port: server.port,
|
||||
role: server.role,
|
||||
address_index,
|
||||
replica_number,
|
||||
shard: shard_idx.parse::<usize>().unwrap(),
|
||||
username: user.username.clone(),
|
||||
pool_name: pool_name.clone(),
|
||||
mirrors: mirror_addresses,
|
||||
stats: Arc::new(AddressStats::default()),
|
||||
};
|
||||
|
||||
address_id += 1;
|
||||
|
||||
if server.role == Role::Replica {
|
||||
replica_number += 1;
|
||||
}
|
||||
|
||||
// We assume every server in the pool share user/passwords
|
||||
let auth_passthrough = AuthPassthrough::from_pool_config(pool_config);
|
||||
|
||||
if let Some(apt) = &auth_passthrough {
|
||||
match apt.fetch_hash(&address).await {
|
||||
Ok(ok) => {
|
||||
if let Some(ref pool_auth_hash_value) = *(pool_auth_hash.read()) {
|
||||
if ok != *pool_auth_hash_value {
|
||||
warn!("Hash is not the same across shards of the same pool, client auth will \
|
||||
be done using last obtained hash. Server: {}:{}, Database: {}", server.host, server.port, shard.database);
|
||||
}
|
||||
}
|
||||
debug!("Hash obtained for {:?}", address);
|
||||
{
|
||||
let mut pool_auth_hash = pool_auth_hash.write();
|
||||
*pool_auth_hash = Some(ok.clone());
|
||||
}
|
||||
},
|
||||
Err(err) => warn!("Could not obtain password hashes using auth_query config, ignoring. Error: {:?}", err),
|
||||
}
|
||||
}
|
||||
|
||||
let manager = ServerPool::new(
|
||||
address.clone(),
|
||||
user.clone(),
|
||||
&shard.database,
|
||||
client_server_map.clone(),
|
||||
pool_stats.clone(),
|
||||
pool_auth_hash.clone(),
|
||||
);
|
||||
|
||||
let connect_timeout = match pool_config.connect_timeout {
|
||||
Some(connect_timeout) => connect_timeout,
|
||||
None => config.general.connect_timeout,
|
||||
};
|
||||
|
||||
let idle_timeout = match pool_config.idle_timeout {
|
||||
Some(idle_timeout) => idle_timeout,
|
||||
None => config.general.idle_timeout,
|
||||
};
|
||||
|
||||
let pool = Pool::builder()
|
||||
.max_size(user.pool_size)
|
||||
.connection_timeout(std::time::Duration::from_millis(connect_timeout))
|
||||
.idle_timeout(Some(std::time::Duration::from_millis(idle_timeout)))
|
||||
.test_on_check_out(false)
|
||||
.build(manager)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
pools.push(pool);
|
||||
servers.push(address);
|
||||
}
|
||||
|
||||
shards.push(pools);
|
||||
addresses.push(servers);
|
||||
banlist.push(HashMap::new());
|
||||
}
|
||||
|
||||
assert_eq!(shards.len(), addresses.len());
|
||||
if let Some(ref _auth_hash) = *(pool_auth_hash.clone().read()) {
|
||||
info!(
|
||||
"Auth hash obtained from query_auth for pool {{ name: {}, user: {} }}",
|
||||
pool_name, user.username
|
||||
);
|
||||
}
|
||||
|
||||
let pool = ConnectionPool {
|
||||
databases: shards,
|
||||
stats: pool_stats,
|
||||
addresses,
|
||||
banlist: Arc::new(RwLock::new(banlist)),
|
||||
config_hash: new_pool_hash_value,
|
||||
server_info: Arc::new(RwLock::new(BytesMut::new())),
|
||||
auth_hash: pool_auth_hash,
|
||||
settings: PoolSettings {
|
||||
pool_mode: pool_config.pool_mode,
|
||||
load_balancing_mode: pool_config.load_balancing_mode,
|
||||
// shards: pool_config.shards.clone(),
|
||||
shards: shard_ids.len(),
|
||||
user: user.clone(),
|
||||
default_role: match pool_config.default_role.as_str() {
|
||||
"any" => None,
|
||||
"replica" => Some(Role::Replica),
|
||||
"primary" => Some(Role::Primary),
|
||||
_ => unreachable!(),
|
||||
},
|
||||
query_parser_enabled: pool_config.query_parser_enabled,
|
||||
primary_reads_enabled: pool_config.primary_reads_enabled,
|
||||
sharding_function: pool_config.sharding_function,
|
||||
automatic_sharding_key: pool_config.automatic_sharding_key.clone(),
|
||||
healthcheck_delay: config.general.healthcheck_delay,
|
||||
healthcheck_timeout: config.general.healthcheck_timeout,
|
||||
ban_time: config.general.ban_time,
|
||||
sharding_key_regex: pool_config
|
||||
.sharding_key_regex
|
||||
.clone()
|
||||
.map(|regex| Regex::new(regex.as_str()).unwrap()),
|
||||
shard_id_regex: pool_config
|
||||
.shard_id_regex
|
||||
.clone()
|
||||
.map(|regex| Regex::new(regex.as_str()).unwrap()),
|
||||
regex_search_limit: pool_config.regex_search_limit.unwrap_or(1000),
|
||||
auth_query: pool_config.auth_query.clone(),
|
||||
auth_query_user: pool_config.auth_query_user.clone(),
|
||||
auth_query_password: pool_config.auth_query_password.clone(),
|
||||
},
|
||||
query_parser_enabled: pool_config.query_parser_enabled,
|
||||
primary_reads_enabled: pool_config.primary_reads_enabled,
|
||||
sharding_function: pool_config.sharding_function,
|
||||
automatic_sharding_key: pool_config.automatic_sharding_key.clone(),
|
||||
healthcheck_delay: config.general.healthcheck_delay,
|
||||
healthcheck_timeout: config.general.healthcheck_timeout,
|
||||
ban_time: config.general.ban_time,
|
||||
sharding_key_regex: pool_config
|
||||
.sharding_key_regex
|
||||
.clone()
|
||||
.map(|regex| Regex::new(regex.as_str()).unwrap()),
|
||||
shard_id_regex: pool_config
|
||||
.shard_id_regex
|
||||
.clone()
|
||||
.map(|regex| Regex::new(regex.as_str()).unwrap()),
|
||||
regex_search_limit: pool_config.regex_search_limit.unwrap_or(1000),
|
||||
auth_query: pool_config.auth_query.clone(),
|
||||
auth_query_user: pool_config.auth_query_user.clone(),
|
||||
auth_query_password: pool_config.auth_query_password.clone(),
|
||||
},
|
||||
validated: Arc::new(AtomicBool::new(false)),
|
||||
paused: Arc::new(AtomicBool::new(false)),
|
||||
paused_waiter: Arc::new(Notify::new()),
|
||||
};
|
||||
validated: Arc::new(AtomicBool::new(false)),
|
||||
paused: Arc::new(AtomicBool::new(false)),
|
||||
paused_waiter: Arc::new(Notify::new()),
|
||||
};
|
||||
|
||||
// Connect to the servers to make sure pool configuration is valid
|
||||
// before setting it globally.
|
||||
// Do this async and somewhere else, we don't have to wait here.
|
||||
let mut validate_pool = pool.clone();
|
||||
tokio::task::spawn(async move {
|
||||
let _ = validate_pool.validate().await;
|
||||
});
|
||||
// Connect to the servers to make sure pool configuration is valid
|
||||
// before setting it globally.
|
||||
// Do this async and somewhere else, we don't have to wait here.
|
||||
let mut validate_pool = pool.clone();
|
||||
tokio::task::spawn(async move {
|
||||
let _ = validate_pool.validate().await;
|
||||
});
|
||||
|
||||
// There is one pool per database/user pair.
|
||||
new_pools.insert(PoolIdentifier::new(pool_name, &user.username), pool);
|
||||
// There is one pool per database/user pair.
|
||||
new_pools.insert(PoolIdentifier::new(pool_name, &user.username, secret), pool);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -924,10 +938,10 @@ impl ManageConnection for ServerPool {
|
||||
}
|
||||
|
||||
/// Get the connection pool
|
||||
pub fn get_pool(db: &str, user: &str) -> Option<ConnectionPool> {
|
||||
(*(*POOLS.load()))
|
||||
.get(&PoolIdentifier::new(db, user))
|
||||
.cloned()
|
||||
pub fn get_pool(db: &str, user: &str, secret: Option<String>) -> Option<ConnectionPool> {
|
||||
let identifier = PoolIdentifier::new(db, user, secret);
|
||||
|
||||
(*(*POOLS.load())).get(&identifier).cloned()
|
||||
}
|
||||
|
||||
/// Get a pointer to all configured pools.
|
||||
|
||||
Reference in New Issue
Block a user