Compare commits

...

12 Commits

Author SHA1 Message Date
Lev Kokotov
3ca28a62c4 Dont accept empty passwords 2023-03-30 18:09:01 -07:00
Lev Kokotov
b65c1ddd56 readme 2023-03-30 17:36:49 -07:00
Lev Kokotov
ed31053cdb Fix spec 2023-03-30 17:35:32 -07:00
Lev Kokotov
4969abf355 Hmm 2023-03-30 15:29:10 -07:00
Lev Kokotov
112c0bdae8 Rebased 2023-03-30 15:19:52 -07:00
Lev Kokotov
fef737ea43 fmt 2023-03-30 14:16:50 -07:00
Lev Kokotov
345ee88342 Warn when secrets are too short 2023-03-30 14:16:38 -07:00
Lev Kokotov
db3d6c3baa Some tests 2023-03-30 14:16:36 -07:00
Lev Kokotov
197c32b4e8 Readme 2023-03-30 14:15:30 -07:00
Lev Kokotov
6345c39bd5 fix ci config 2023-03-30 14:15:07 -07:00
Lev Kokotov
32b913af94 update admin 2023-03-30 14:15:07 -07:00
Lev Kokotov
5c673b4333 Zero-downtime password rotation 2023-03-30 14:15:05 -07:00
20 changed files with 891 additions and 456 deletions

View File

@@ -25,7 +25,8 @@ PostgreSQL pooler and proxy (like PgBouncer) with support for sharding, load bal
| Sharding using comments parsing/Regex | **Experimental** | Clients can include shard information (sharding key, shard ID) in the query comments. | | Sharding using comments parsing/Regex | **Experimental** | Clients can include shard information (sharding key, shard ID) in the query comments. |
| Automatic sharding | **Experimental** | PgCat can parse queries, detect sharding keys automatically, and route queries to the correct shard. | | Automatic sharding | **Experimental** | PgCat can parse queries, detect sharding keys automatically, and route queries to the correct shard. |
| Mirroring | **Experimental** | Mirror queries between multiple databases in order to test servers with realistic production traffic. | | Mirroring | **Experimental** | Mirror queries between multiple databases in order to test servers with realistic production traffic. |
| Auth passthrough | **Experimental** | MD5 password authentication can be configured to use an `auth_query` so no cleartext passwords are needed in the config file. | | Auth passthrough | **Experimental** | MD5 password authentication can be configured to use an `auth_query` so no cleartext passwords are needed in the config file. |
| Password rotation | **Experimental** | Allows to rotate passwords without downtime or using third-party tools to manage Postgres authentication. |
## Status ## Status
@@ -244,6 +245,12 @@ The config can be reloaded by sending a `kill -s SIGHUP` to the process or by qu
Mirroring allows to route queries to multiple databases at the same time. This is useful for prewarning replicas before placing them into the active configuration, or for testing different versions of Postgres with live traffic. Mirroring allows to route queries to multiple databases at the same time. This is useful for prewarning replicas before placing them into the active configuration, or for testing different versions of Postgres with live traffic.
### Password rotation
Password rotation allows to specify multiple passwords for a user, so they can connect to PgCat with multiple credentials. This allows distributed applications to change their configuration (connection strings) gradually and for PgCat to monitor their progression in admin statistics. Once the new secret is deployed everywhere, the old one can be removed from PgCat.
This also decouples server passwords from client passwords, allowing to change one without necessarily changing the other.
## License ## License
PgCat is free and open source, released under the MIT license. PgCat is free and open source, released under the MIT license.

View File

@@ -64,7 +64,7 @@ services:
<<: *common-env-pg <<: *common-env-pg
POSTGRES_INITDB_ARGS: --auth-local=md5 --auth-host=md5 --auth=md5 POSTGRES_INITDB_ARGS: --auth-local=md5 --auth-host=md5 --auth=md5
PGPORT: 10432 PGPORT: 10432
command: ["postgres", "-p", "5432", "-c", "shared_preload_libraries=pg_stat_statements", "-c", "pg_stat_statements.track=all", "-c", "pg_stat_statements.max=100000"] command: ["postgres", "-p", "10432", "-c", "shared_preload_libraries=pg_stat_statements", "-c", "pg_stat_statements.track=all", "-c", "pg_stat_statements.max=100000"]
toxiproxy: toxiproxy:
build: . build: .

View File

@@ -58,9 +58,9 @@ tcp_keepalives_count = 5
tcp_keepalives_interval = 5 tcp_keepalives_interval = 5
# Path to TLS Certficate file to use for TLS connections # Path to TLS Certficate file to use for TLS connections
# tls_certificate = "server.cert" # tls_certificate = ".circleci/server.cert"
# Path to TLS private key file to use for TLS connections # Path to TLS private key file to use for TLS connections
# tls_private_key = "server.key" # tls_private_key = ".circleci/server.key"
# User name to access the virtual administrative database (pgbouncer or pgcat) # User name to access the virtual administrative database (pgbouncer or pgcat)
# Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DATABASES`, etc.. # Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DATABASES`, etc..
@@ -122,6 +122,10 @@ idle_timeout = 40000
# Connect timeout can be overwritten in the pool # Connect timeout can be overwritten in the pool
connect_timeout = 3000 connect_timeout = 3000
# auth_query = "SELECT * FROM public.user_lookup('$1')"
# auth_query_user = "postgres"
# auth_query_password = "postgres"
# User configs are structured as pool.<pool_name>.users.<user_index> # User configs are structured as pool.<pool_name>.users.<user_index>
# This secion holds the credentials for users that may connect to this cluster # This secion holds the credentials for users that may connect to this cluster
[pools.sharded_db.users.0] [pools.sharded_db.users.0]
@@ -129,6 +133,10 @@ connect_timeout = 3000
username = "sharding_user" username = "sharding_user"
# Postgresql password # Postgresql password
password = "sharding_user" password = "sharding_user"
# # Passwords the client can use to connect. Useful for password rotations.
# secrets = [ "secret_one", "secret_two" ]
# Maximum number of server connections that can be established for this user # Maximum number of server connections that can be established for this user
# The maximum number of connection from a single Pgcat process to any database in the cluster # The maximum number of connection from a single Pgcat process to any database in the cluster
# is the sum of pool_size across all users. # is the sum of pool_size across all users.

View File

@@ -259,6 +259,7 @@ where
let columns = vec![ let columns = vec![
("database", DataType::Text), ("database", DataType::Text),
("user", DataType::Text), ("user", DataType::Text),
("secret", DataType::Text),
("pool_mode", DataType::Text), ("pool_mode", DataType::Text),
("cl_idle", DataType::Numeric), ("cl_idle", DataType::Numeric),
("cl_active", DataType::Numeric), ("cl_active", DataType::Numeric),
@@ -276,10 +277,11 @@ where
let mut res = BytesMut::new(); let mut res = BytesMut::new();
res.put(row_description(&columns)); res.put(row_description(&columns));
for ((_user_pool, _pool), pool_stats) in all_pool_stats { for (_, pool_stats) in all_pool_stats {
let mut row = vec![ let mut row = vec![
pool_stats.database(), pool_stats.database(),
pool_stats.user(), pool_stats.user(),
pool_stats.redacted_secret(),
pool_stats.pool_mode().to_string(), pool_stats.pool_mode().to_string(),
]; ];
pool_stats.populate_row(&mut row); pool_stats.populate_row(&mut row);
@@ -780,7 +782,7 @@ where
let database = parts[0]; let database = parts[0];
let user = parts[1]; let user = parts[1];
match get_pool(database, user) { match get_pool(database, user, None) {
Some(pool) => { Some(pool) => {
pool.pause(); pool.pause();
@@ -827,7 +829,7 @@ where
let database = parts[0]; let database = parts[0];
let user = parts[1]; let user = parts[1];
match get_pool(database, user) { match get_pool(database, user, None) {
Some(pool) => { Some(pool) => {
pool.resume(); pool.resume();
@@ -895,13 +897,20 @@ where
res.put(row_description(&vec![ res.put(row_description(&vec![
("name", DataType::Text), ("name", DataType::Text),
("pool_mode", DataType::Text), ("pool_mode", DataType::Text),
("secret", DataType::Text),
])); ]));
for (user_pool, pool) in get_all_pools() { for (user_pool, pool) in get_all_pools() {
let pool_config = &pool.settings; let pool_config = &pool.settings;
let redacted_secret = match user_pool.secret {
Some(secret) => format!("****{}", &secret[secret.len() - 4..]),
None => "<no secret>".to_string(),
};
res.put(data_row(&vec![ res.put(data_row(&vec![
user_pool.user.clone(), user_pool.user.clone(),
pool_config.pool_mode.to_string(), pool_config.pool_mode.to_string(),
redacted_secret,
])); ]));
} }

452
src/auth.rs Normal file
View File

@@ -0,0 +1,452 @@
//! Module implementing various client authentication mechanisms.
//!
//! Currently supported: plain (via TLS), md5 (via TLS and plain text connection).
use crate::errors::Error;
use crate::tokio::io::AsyncReadExt;
use crate::{
auth_passthrough::AuthPassthrough,
config::get_config,
messages::{
error_response, md5_hash_password, md5_hash_second_pass, write_all, wrong_password,
},
pool::{get_pool, ConnectionPool},
};
use bytes::{BufMut, BytesMut};
use log::debug;
async fn refetch_auth_hash<S>(
pool: &ConnectionPool,
stream: &mut S,
username: &str,
pool_name: &str,
) -> Result<String, Error>
where
S: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send,
{
let config = get_config();
debug!("Fetching auth hash");
if config.is_auth_query_configured() {
let address = pool.address(0, 0);
if let Some(apt) = AuthPassthrough::from_pool_settings(&pool.settings) {
let hash = apt.fetch_hash(address).await?;
debug!("Auth query succeeded");
return Ok(hash);
}
} else {
debug!("Auth query not configured on pool");
}
error_response(
stream,
&format!(
"No password set and auth passthrough failed for database: {}, user: {}",
pool_name, username
),
)
.await?;
Err(Error::ClientError(format!(
"Could not obtain hash for {{ username: {:?}, database: {:?} }}. Auth passthrough not enabled.",
pool_name, username
)))
}
/// Read 'p' message from client.
async fn response<R>(stream: &mut R) -> Result<Vec<u8>, Error>
where
R: tokio::io::AsyncRead + std::marker::Unpin + std::marker::Send,
{
let code = match stream.read_u8().await {
Ok(code) => code,
Err(_) => {
return Err(Error::SocketError(
"Error reading password code from client".to_string(),
))
}
};
if code as char != 'p' {
return Err(Error::SocketError(format!("Expected p, got {}", code)));
}
let len = match stream.read_i32().await {
Ok(len) => len,
Err(_) => {
return Err(Error::SocketError(
"Error reading password length from client".to_string(),
))
}
};
let mut response = vec![0; (len - 4) as usize];
// Too short to be a password (null-terminated)
if response.len() < 2 {
return Err(Error::ClientError(format!("Password response too short")));
}
match stream.read_exact(&mut response).await {
Ok(_) => (),
Err(_) => {
return Err(Error::SocketError(
"Error reading password from client".to_string(),
))
}
};
Ok(response.to_vec())
}
/// Make sure the pool we authenticated to has at least one server connection
/// that can serve our request.
async fn validate_pool<W>(
stream: &mut W,
mut pool: ConnectionPool,
username: &str,
pool_name: &str,
) -> Result<(), Error>
where
W: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send,
{
if !pool.validated() {
match pool.validate().await {
Ok(_) => Ok(()),
Err(err) => {
error_response(
stream,
&format!("Pool down for database: {}, user: {}", pool_name, username,),
)
.await?;
Err(Error::ClientError(format!("Pool down: {:?}", err)))
}
}
} else {
Ok(())
}
}
/// Clear text authentication.
///
/// The client will send the password in plain text over the wire.
/// To protect against obvious security issues, this is only used over TLS.
///
/// Clear text authentication is used to support zero-downtime password rotation.
/// It allows the client to use multiple passwords when talking to the PgCat
/// while the password is being rotated across multiple app instances.
pub struct ClearText {
username: String,
pool_name: String,
application_name: String,
}
impl ClearText {
/// Create a new ClearText authentication mechanism.
pub fn new(username: &str, pool_name: &str, application_name: &str) -> ClearText {
ClearText {
username: username.to_string(),
pool_name: pool_name.to_string(),
application_name: application_name.to_string(),
}
}
/// Issue 'R' clear text challenge to client.
pub async fn challenge<W>(&self, stream: &mut W) -> Result<(), Error>
where
W: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send,
{
debug!("Sending plain challenge");
let mut msg = BytesMut::new();
msg.put_u8(b'R');
msg.put_i32(8);
msg.put_i32(3); // Clear text
write_all(stream, msg).await
}
/// Authenticate client with server password or secret.
pub async fn authenticate<R, W>(
&self,
read: &mut R,
write: &mut W,
) -> Result<Option<String>, Error>
where
R: tokio::io::AsyncRead + std::marker::Unpin + std::marker::Send,
W: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send,
{
let response = response(read).await?;
let secret = String::from_utf8_lossy(&response[0..response.len() - 1]).to_string();
match get_pool(&self.pool_name, &self.username, Some(secret.clone())) {
None => match get_pool(&self.pool_name, &self.username, None) {
Some(pool) => {
match pool.settings.user.password {
Some(ref password) => {
if password != &secret {
wrong_password(write, &self.username).await?;
Err(Error::ClientError(format!(
"Invalid password {{ username: {}, pool_name: {}, application_name: {} }}",
self.username, self.pool_name, self.application_name
)))
} else {
validate_pool(write, pool, &self.username, &self.pool_name).await?;
Ok(None)
}
}
None => {
// Server is storing hashes, we can't query it for the plain text password.
error_response(
write,
&format!(
"No server password configured for database: {}, user: {}",
self.pool_name, self.username
),
)
.await?;
Err(Error::ClientError(format!(
"No server password configured for {{ username: {}, pool_name: {}, application_name: {} }}",
self.username, self.pool_name, self.application_name
)))
}
}
}
None => {
error_response(
write,
&format!(
"No pool configured for database: {}, user: {}",
self.pool_name, self.username
),
)
.await?;
Err(Error::ClientError(format!(
"Invalid pool name {{ username: {}, pool_name: {}, application_name: {} }}",
self.username, self.pool_name, self.application_name
)))
}
},
Some(pool) => {
validate_pool(write, pool, &self.username, &self.pool_name).await?;
Ok(Some(secret))
}
}
}
}
/// MD5 hash authentication.
///
/// Deprecated, but widely used everywhere, and currently required for poolers
/// to authencticate clients without involving Postgres.
///
/// Admin clients are required to use MD5.
pub struct Md5 {
username: String,
pool_name: String,
application_name: String,
salt: [u8; 4],
admin: bool,
}
impl Md5 {
pub fn new(username: &str, pool_name: &str, application_name: &str, admin: bool) -> Md5 {
let salt: [u8; 4] = [
rand::random(),
rand::random(),
rand::random(),
rand::random(),
];
Md5 {
username: username.to_string(),
pool_name: pool_name.to_string(),
application_name: application_name.to_string(),
salt,
admin,
}
}
/// Issue a 'R' MD5 challenge to the client.
pub async fn challenge<W>(&self, stream: &mut W) -> Result<(), Error>
where
W: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send,
{
let mut res = BytesMut::new();
res.put_u8(b'R');
res.put_i32(12);
res.put_i32(5); // MD5
res.put_slice(&self.salt[..]);
write_all(stream, res).await
}
/// Authenticate client with MD5. This is used for both admin and normal users.
pub async fn authenticate<R, W>(&self, read: &mut R, write: &mut W) -> Result<(), Error>
where
R: tokio::io::AsyncRead + std::marker::Unpin + std::marker::Send,
W: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send,
{
let password_hash = response(read).await?;
if self.admin {
let config = get_config();
// Compare server and client hashes.
let our_hash = md5_hash_password(
&config.general.admin_username,
&config.general.admin_password,
&self.salt,
);
if our_hash != password_hash {
wrong_password(write, &self.username).await?;
Err(Error::ClientError(format!(
"Invalid password {{ username: {}, pool_name: {}, application_name: {} }}",
self.username, self.pool_name, self.application_name
)))
} else {
Ok(())
}
} else {
match get_pool(&self.pool_name, &self.username, None) {
Some(pool) => {
match &pool.settings.user.password {
Some(ref password) => {
let our_hash = md5_hash_password(&self.username, password, &self.salt);
if our_hash != password_hash {
wrong_password(write, &self.username).await?;
Err(Error::ClientError(format!(
"Invalid password {{ username: {}, pool_name: {}, application_name: {} }}",
self.username, self.pool_name, self.application_name
)))
} else {
validate_pool(write, pool, &self.username, &self.pool_name).await?;
Ok(())
}
}
None => {
if !get_config().is_auth_query_configured() {
error_response(
write,
&format!(
"No password configured and auth_query is not set: {}, user: {}",
self.pool_name, self.username
),
)
.await?;
return Err(Error::ClientError(format!(
"No password configured and auth_query is not set"
)));
}
debug!("Using auth_query");
// Fetch hash from server
let hash = (*pool.auth_hash.read()).clone();
let hash = match hash {
Some(hash) => {
debug!("Using existing hash: {}", hash);
hash.clone()
}
None => {
debug!("Pool has no hash set, fetching new one");
let hash = refetch_auth_hash(
&pool,
write,
&self.username,
&self.pool_name,
)
.await?;
(*pool.auth_hash.write()) = Some(hash.clone());
hash
}
};
let our_hash = md5_hash_second_pass(&hash, &self.salt);
// Compare hashes
if our_hash != password_hash {
debug!("Pool auth query hash did not match, refetching");
// Server hash maybe changed
let hash = refetch_auth_hash(
&pool,
write,
&self.username,
&self.pool_name,
)
.await?;
let our_hash = md5_hash_second_pass(&hash, &self.salt);
if our_hash != password_hash {
debug!("Auth query failed, passwords don't match");
wrong_password(write, &self.username).await?;
Err(Error::ClientError(format!(
"Invalid password {{ username: {}, pool_name: {}, application_name: {} }}",
self.username, self.pool_name, self.application_name
)))
} else {
(*pool.auth_hash.write()) = Some(hash);
validate_pool(
write,
pool.clone(),
&self.username,
&self.pool_name,
)
.await?;
Ok(())
}
} else {
validate_pool(write, pool.clone(), &self.username, &self.pool_name)
.await?;
Ok(())
}
}
}
}
None => {
error_response(
write,
&format!(
"No pool configured for database: {}, user: {}",
self.pool_name, self.username
),
)
.await?;
return Err(Error::ClientError(format!(
"Invalid pool name {{ username: {}, pool_name: {}, application_name: {} }}",
self.username, self.pool_name, self.application_name
)));
}
}
}
}
}

View File

@@ -73,23 +73,25 @@ impl AuthPassthrough {
password: Some(self.password.clone()), password: Some(self.password.clone()),
pool_size: 1, pool_size: 1,
statement_timeout: 0, statement_timeout: 0,
secrets: None,
}; };
let user = &address.username; let user = &address.username;
debug!("Connecting to server to obtain auth hashes."); debug!("Connecting to server to obtain auth hashes.");
let auth_query = self.query.replace("$1", user); let auth_query = self.query.replace("$1", user);
match Server::exec_simple_query(address, &auth_user, &auth_query).await { match Server::exec_simple_query(address, &auth_user, &auth_query).await {
Ok(password_data) => { Ok(password_data) => {
if password_data.len() == 2 && password_data.first().unwrap() == user { if password_data.len() == 2 && password_data.first().unwrap() == user {
if let Some(stripped_hash) = password_data.last().unwrap().to_string().strip_prefix("md5") { if let Some(stripped_hash) = password_data.last().unwrap().to_string().strip_prefix("md5") {
Ok(stripped_hash.to_string()) Ok(stripped_hash.to_string())
} } else {
else { Err(Error::AuthPassthroughError(
Err(Error::AuthPassthroughError( "Obtained hash from auth_query does not seem to be in md5 format.".to_string(),
"Obtained hash from auth_query does not seem to be in md5 format.".to_string(), ))
)) }
}
} else { } else {
Err(Error::AuthPassthroughError( Err(Error::AuthPassthroughError(
"Data obtained from query does not follow the scheme 'user','hash'." "Data obtained from query does not follow the scheme 'user','hash'."
@@ -97,11 +99,12 @@ impl AuthPassthrough {
)) ))
} }
} }
Err(err) => { Err(err) => {
Err(Error::AuthPassthroughError( Err(Error::AuthPassthroughError(
format!("Error trying to obtain password from auth_query, ignoring hash for user '{}'. Error: {:?}", format!("Error trying to obtain password from auth_query, ignoring hash for user '{}'. Error: {:?}",
user, err))) user, err)))
}
} }
}
} }
} }

View File

@@ -2,7 +2,7 @@ use crate::errors::Error;
use crate::pool::BanReason; use crate::pool::BanReason;
/// Handle clients by pretending to be a PostgreSQL server. /// Handle clients by pretending to be a PostgreSQL server.
use bytes::{Buf, BufMut, BytesMut}; use bytes::{Buf, BufMut, BytesMut};
use log::{debug, error, info, trace, warn}; use log::{debug, error, info, trace};
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use std::time::Instant; use std::time::Instant;
@@ -12,7 +12,6 @@ use tokio::sync::broadcast::Receiver;
use tokio::sync::mpsc::Sender; use tokio::sync::mpsc::Sender;
use crate::admin::{generate_server_info_for_admin, handle_admin}; use crate::admin::{generate_server_info_for_admin, handle_admin};
use crate::auth_passthrough::AuthPassthrough;
use crate::config::{get_config, get_idle_client_in_transaction_timeout, Address, PoolMode}; use crate::config::{get_config, get_idle_client_in_transaction_timeout, Address, PoolMode};
use crate::constants::*; use crate::constants::*;
use crate::messages::*; use crate::messages::*;
@@ -90,6 +89,9 @@ pub struct Client<S, T> {
/// Application name for this client (defaults to pgcat) /// Application name for this client (defaults to pgcat)
application_name: String, application_name: String,
/// Which secret the user is using to connect, if any.
secret: Option<String>,
/// Used to notify clients about an impending shutdown /// Used to notify clients about an impending shutdown
shutdown: Receiver<()>, shutdown: Receiver<()>,
} }
@@ -290,7 +292,7 @@ pub async fn client_entrypoint(
/// Handle the first message the client sends. /// Handle the first message the client sends.
async fn get_startup<S>(stream: &mut S) -> Result<(ClientConnectionType, BytesMut), Error> async fn get_startup<S>(stream: &mut S) -> Result<(ClientConnectionType, BytesMut), Error>
where where
S: tokio::io::AsyncRead + std::marker::Unpin + tokio::io::AsyncWrite, S: tokio::io::AsyncRead + std::marker::Unpin + tokio::io::AsyncWrite + std::marker::Send,
{ {
// Get startup message length. // Get startup message length.
let len = match stream.read_i32().await { let len = match stream.read_i32().await {
@@ -377,24 +379,10 @@ pub async fn startup_tls(
} }
} }
async fn refetch_auth_hash(pool: &ConnectionPool) -> Result<String, Error> {
let address = pool.address(0, 0);
if let Some(apt) = AuthPassthrough::from_pool_settings(&pool.settings) {
let hash = apt.fetch_hash(address).await?;
return Ok(hash);
}
Err(Error::ClientError(format!(
"Could not obtain hash for {{ username: {:?}, database: {:?} }}. Auth passthrough not enabled.",
address.username, address.database
)))
}
impl<S, T> Client<S, T> impl<S, T> Client<S, T>
where where
S: tokio::io::AsyncRead + std::marker::Unpin, S: tokio::io::AsyncRead + std::marker::Unpin + std::marker::Send,
T: tokio::io::AsyncWrite + std::marker::Unpin, T: tokio::io::AsyncWrite + std::marker::Unpin + std::marker::Send,
{ {
pub fn is_admin(&self) -> bool { pub fn is_admin(&self) -> bool {
self.admin self.admin
@@ -457,161 +445,39 @@ where
let process_id: i32 = rand::random(); let process_id: i32 = rand::random();
let secret_key: i32 = rand::random(); let secret_key: i32 = rand::random();
// Perform MD5 authentication. let config = get_config();
// TODO: Add SASL support.
let salt = md5_challenge(&mut write).await?;
let code = match read.read_u8().await { let secret = if admin {
Ok(p) => p, debug!("Using md5 auth for admin");
Err(_) => return Err(Error::SocketError(format!("Error reading password code from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))), let auth = crate::auth::Md5::new(&username, &pool_name, &application_name, true);
auth.challenge(&mut write).await?;
auth.authenticate(&mut read, &mut write).await?;
None
} else if !config.tls_enabled() {
debug!("Using md5 auth");
let auth = crate::auth::Md5::new(&username, &pool_name, &application_name, false);
auth.challenge(&mut write).await?;
auth.authenticate(&mut read, &mut write).await?;
None
} else {
debug!("Using plain auth");
let auth = crate::auth::ClearText::new(&username, &pool_name, &application_name);
auth.challenge(&mut write).await?;
auth.authenticate(&mut read, &mut write).await?
}; };
// PasswordMessage // Authenticated admin user.
if code as char != 'p' {
return Err(Error::ProtocolSyncError(format!(
"Expected p, got {}",
code as char
)));
}
let len = match read.read_i32().await {
Ok(len) => len,
Err(_) => return Err(Error::SocketError(format!("Error reading password message length from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))),
};
let mut password_response = vec![0u8; (len - 4) as usize];
match read.read_exact(&mut password_response).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError(format!("Error reading password message from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))),
};
// Authenticate admin user.
let (transaction_mode, server_info) = if admin { let (transaction_mode, server_info) = if admin {
let config = get_config();
// Compare server and client hashes.
let password_hash = md5_hash_password(
&config.general.admin_username,
&config.general.admin_password,
&salt,
);
if password_hash != password_response {
warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name);
wrong_password(&mut write, username).await?;
return Err(Error::ClientError(format!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name)));
}
(false, generate_server_info_for_admin()) (false, generate_server_info_for_admin())
} }
// Authenticate normal user. // Authenticated normal user.
else { else {
let mut pool = match get_pool(pool_name, username) { let pool = get_pool(&pool_name, &username, secret.clone()).unwrap();
Some(pool) => pool,
None => {
error_response(
&mut write,
&format!(
"No pool configured for database: {:?}, user: {:?}",
pool_name, username
),
)
.await?;
return Err(Error::ClientError(format!("Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name)));
}
};
// Obtain the hash to compare, we give preference to that written in cleartext in config
// if there is nothing set in cleartext and auth passthrough (auth_query) is configured, we use the hash obtained
// when the pool was created. If there is no hash there, we try to fetch it one more time.
let password_hash = if let Some(password) = &pool.settings.user.password {
Some(md5_hash_password(username, password, &salt))
} else {
if !get_config().is_auth_query_configured() {
return Err(Error::ClientError(format!("Client auth not possible, no cleartext password set for username: {:?} in config and auth passthrough (query_auth) is not set up.", username)));
}
let mut hash = (*pool.auth_hash.read()).clone();
if hash.is_none() {
warn!("Query auth configured but no hash password found for pool {}. Will try to refetch it.", pool_name);
match refetch_auth_hash(&pool).await {
Ok(fetched_hash) => {
warn!("Password for {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}, obtained. Updating.", username, pool_name, application_name);
{
let mut pool_auth_hash = pool.auth_hash.write();
*pool_auth_hash = Some(fetched_hash.clone());
}
hash = Some(fetched_hash);
}
Err(err) => {
return Err(
Error::ClientError(
format!("No cleartext password set, and no auth passthrough could not obtain the hash from server for {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}, the error was: {:?}",
username,
pool_name,
application_name,
err)
)
);
}
}
};
Some(md5_hash_second_pass(&hash.unwrap(), &salt))
};
// Once we have the resulting hash, we compare with what the client gave us.
// If they do not match and auth query is set up, we try to refetch the hash one more time
// to see if the password has changed since the pool was created.
//
// @TODO: we could end up fetching again the same password twice (see above).
if password_hash.unwrap() != password_response {
warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}, will try to refetch it.", username, pool_name, application_name);
let fetched_hash = refetch_auth_hash(&pool).await?;
let new_password_hash = md5_hash_second_pass(&fetched_hash, &salt);
// Ok password changed in server an auth is possible.
if new_password_hash == password_response {
warn!("Password for {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}, changed in server. Updating.", username, pool_name, application_name);
{
let mut pool_auth_hash = pool.auth_hash.write();
*pool_auth_hash = Some(fetched_hash);
}
} else {
wrong_password(&mut write, username).await?;
return Err(Error::ClientError(format!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name)));
}
}
let transaction_mode = pool.settings.pool_mode == PoolMode::Transaction; let transaction_mode = pool.settings.pool_mode == PoolMode::Transaction;
// If the pool hasn't been validated yet,
// connect to the servers and figure out what's what.
if !pool.validated() {
match pool.validate().await {
Ok(_) => (),
Err(err) => {
error_response(
&mut write,
&format!(
"Pool down for database: {:?}, user: {:?}",
pool_name, username
),
)
.await?;
return Err(Error::ClientError(format!("Pool down: {:?}", err)));
}
}
}
(transaction_mode, pool.server_info()) (transaction_mode, pool.server_info())
}; };
debug!("Password authentication successful"); debug!("Authentication successful");
auth_ok(&mut write).await?; auth_ok(&mut write).await?;
write_all(&mut write, server_info).await?; write_all(&mut write, server_info).await?;
@@ -619,7 +485,7 @@ where
ready_for_query(&mut write).await?; ready_for_query(&mut write).await?;
trace!("Startup OK"); trace!("Startup OK");
let pool_stats = match get_pool(pool_name, username) { let pool_stats = match get_pool(pool_name, username, secret.clone()) {
Some(pool) => { Some(pool) => {
if !admin { if !admin {
pool.stats pool.stats
@@ -659,6 +525,7 @@ where
application_name: application_name.to_string(), application_name: application_name.to_string(),
shutdown, shutdown,
connected_to_server: false, connected_to_server: false,
secret,
}) })
} }
@@ -693,6 +560,7 @@ where
application_name: String::from("undefined"), application_name: String::from("undefined"),
shutdown, shutdown,
connected_to_server: false, connected_to_server: false,
secret: None,
}) })
} }
@@ -1200,7 +1068,7 @@ where
/// Retrieve connection pool, if it exists. /// Retrieve connection pool, if it exists.
/// Return an error to the client otherwise. /// Return an error to the client otherwise.
async fn get_pool(&mut self) -> Result<ConnectionPool, Error> { async fn get_pool(&mut self) -> Result<ConnectionPool, Error> {
match get_pool(&self.pool_name, &self.username) { match get_pool(&self.pool_name, &self.username, self.secret.clone()) {
Some(pool) => Ok(pool), Some(pool) => Ok(pool),
None => { None => {
error_response( error_response(

View File

@@ -1,6 +1,6 @@
/// Parse the configuration file. /// Parse the configuration file.
use arc_swap::ArcSwap; use arc_swap::ArcSwap;
use log::{error, info}; use log::{error, info, warn};
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use regex::Regex; use regex::Regex;
use serde_derive::{Deserialize, Serialize}; use serde_derive::{Deserialize, Serialize};
@@ -181,6 +181,26 @@ pub struct User {
pub pool_size: u32, pub pool_size: u32,
#[serde(default)] // 0 #[serde(default)] // 0
pub statement_timeout: u64, pub statement_timeout: u64,
pub secrets: Option<Vec<String>>,
}
impl User {
fn validate(&self) -> Result<(), Error> {
match self.secrets {
Some(ref secrets) => {
for secret in secrets.iter() {
if secret.len() < 16 {
warn!(
"[user: {}] Secret is too short (less than 16 characters)",
self.username
);
}
}
}
None => (),
}
Ok(())
}
} }
impl Default for User { impl Default for User {
@@ -190,6 +210,7 @@ impl Default for User {
password: None, password: None,
pool_size: 15, pool_size: 15,
statement_timeout: 0, statement_timeout: 0,
secrets: None,
} }
} }
} }
@@ -508,6 +529,10 @@ impl Pool {
None => None, None => None,
}; };
for user in self.users.iter() {
user.1.validate()?;
}
Ok(()) Ok(())
} }
} }
@@ -657,6 +682,11 @@ impl Config {
} }
} }
} }
/// Checks that we configured TLS.
pub fn tls_enabled(&self) -> bool {
self.general.tls_certificate.is_some() && self.general.tls_private_key.is_some()
}
} }
impl Default for Config { impl Default for Config {

View File

@@ -61,6 +61,7 @@ use std::sync::Arc;
use tokio::sync::broadcast; use tokio::sync::broadcast;
mod admin; mod admin;
mod auth;
mod auth_passthrough; mod auth_passthrough;
mod client; mod client;
mod config; mod config;

View File

@@ -1,7 +1,7 @@
/// Helper functions to send one-off protocol messages /// Helper functions to send one-off protocol messages
/// and handle TcpStream (TCP socket). /// and handle TcpStream (TCP socket).
use bytes::{Buf, BufMut, BytesMut}; use bytes::{Buf, BufMut, BytesMut};
use log::error; use log::{debug, error};
use md5::{Digest, Md5}; use md5::{Digest, Md5};
use socket2::{SockRef, TcpKeepalive}; use socket2::{SockRef, TcpKeepalive};
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
@@ -46,29 +46,6 @@ where
write_all(stream, auth_ok).await write_all(stream, auth_ok).await
} }
/// Generate md5 password challenge.
pub async fn md5_challenge<S>(stream: &mut S) -> Result<[u8; 4], Error>
where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
// let mut rng = rand::thread_rng();
let salt: [u8; 4] = [
rand::random(),
rand::random(),
rand::random(),
rand::random(),
];
let mut res = BytesMut::new();
res.put_u8(b'R');
res.put_i32(12);
res.put_i32(5); // MD5
res.put_slice(&salt[..]);
write_all(stream, res).await?;
Ok(salt)
}
/// Give the client the process_id and secret we generated /// Give the client the process_id and secret we generated
/// used in query cancellation. /// used in query cancellation.
pub async fn backend_key_data<S>( pub async fn backend_key_data<S>(
@@ -257,6 +234,8 @@ pub async fn md5_password_with_hash<S>(stream: &mut S, hash: &str, salt: &[u8])
where where
S: tokio::io::AsyncWrite + std::marker::Unpin, S: tokio::io::AsyncWrite + std::marker::Unpin,
{ {
debug!("Sending hash {} to server", hash);
let password = md5_hash_second_pass(hash, salt); let password = md5_hash_second_pass(hash, salt);
let mut message = BytesMut::with_capacity(password.len() as usize + 5); let mut message = BytesMut::with_capacity(password.len() as usize + 5);

View File

@@ -34,7 +34,7 @@ impl MirroredClient {
None => (default, default, crate::config::Pool::default()), None => (default, default, crate::config::Pool::default()),
}; };
let identifier = PoolIdentifier::new(&self.database, &self.user.username); let identifier = PoolIdentifier::new(&self.database, &self.user.username, None);
let manager = ServerPool::new( let manager = ServerPool::new(
self.address.clone(), self.address.clone(),

View File

@@ -59,24 +59,22 @@ pub struct PoolIdentifier {
/// The username the client connects with. Each user gets its own pool. /// The username the client connects with. Each user gets its own pool.
pub user: String, pub user: String,
/// The client secret (password).
pub secret: Option<String>,
} }
impl PoolIdentifier { impl PoolIdentifier {
/// Create a new user/pool identifier. /// Create a new user/pool identifier.
pub fn new(db: &str, user: &str) -> PoolIdentifier { pub fn new(db: &str, user: &str, secret: Option<String>) -> PoolIdentifier {
PoolIdentifier { PoolIdentifier {
db: db.to_string(), db: db.to_string(),
user: user.to_string(), user: user.to_string(),
secret,
} }
} }
} }
impl From<&Address> for PoolIdentifier {
fn from(address: &Address) -> PoolIdentifier {
PoolIdentifier::new(&address.database, &address.username)
}
}
/// Pool settings. /// Pool settings.
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct PoolSettings { pub struct PoolSettings {
@@ -210,224 +208,241 @@ impl ConnectionPool {
// There is one pool per database/user pair. // There is one pool per database/user pair.
for user in pool_config.users.values() { for user in pool_config.users.values() {
let old_pool_ref = get_pool(pool_name, &user.username); let mut secrets = match &user.secrets {
let identifier = PoolIdentifier::new(pool_name, &user.username); Some(_) => user
.secrets
match old_pool_ref { .as_ref()
Some(pool) => { .unwrap()
// If the pool hasn't changed, get existing reference and insert it into the new_pools. .iter()
// We replace all pools at the end, but if the reference is kept, the pool won't get re-created (bb8). .map(|secret| Some(secret.to_string()))
if pool.config_hash == new_pool_hash_value { .collect::<Vec<Option<String>>>(),
info!( None => vec![],
"[pool: {}][user: {}] has not changed",
pool_name, user.username
);
new_pools.insert(identifier.clone(), pool.clone());
continue;
}
}
None => (),
}
info!(
"[pool: {}][user: {}] creating new pool",
pool_name, user.username
);
let mut shards = Vec::new();
let mut addresses = Vec::new();
let mut banlist = Vec::new();
let mut shard_ids = pool_config
.shards
.clone()
.into_keys()
.collect::<Vec<String>>();
let pool_stats = Arc::new(PoolStats::new(identifier, pool_config.clone()));
// Allow the pool to be seen in statistics
pool_stats.register(pool_stats.clone());
// Sort by shard number to ensure consistency.
shard_ids.sort_by_key(|k| k.parse::<i64>().unwrap());
let pool_auth_hash: Arc<RwLock<Option<String>>> = Arc::new(RwLock::new(None));
for shard_idx in &shard_ids {
let shard = &pool_config.shards[shard_idx];
let mut pools = Vec::new();
let mut servers = Vec::new();
let mut replica_number = 0;
// Load Mirror settings
for (address_index, server) in shard.servers.iter().enumerate() {
let mut mirror_addresses = vec![];
if let Some(mirror_settings_vec) = &shard.mirrors {
for (mirror_idx, mirror_settings) in
mirror_settings_vec.iter().enumerate()
{
if mirror_settings.mirroring_target_index != address_index {
continue;
}
mirror_addresses.push(Address {
id: address_id,
database: shard.database.clone(),
host: mirror_settings.host.clone(),
port: mirror_settings.port,
role: server.role,
address_index: mirror_idx,
replica_number,
shard: shard_idx.parse::<usize>().unwrap(),
username: user.username.clone(),
pool_name: pool_name.clone(),
mirrors: vec![],
stats: Arc::new(AddressStats::default()),
});
address_id += 1;
}
}
let address = Address {
id: address_id,
database: shard.database.clone(),
host: server.host.clone(),
port: server.port,
role: server.role,
address_index,
replica_number,
shard: shard_idx.parse::<usize>().unwrap(),
username: user.username.clone(),
pool_name: pool_name.clone(),
mirrors: mirror_addresses,
stats: Arc::new(AddressStats::default()),
};
address_id += 1;
if server.role == Role::Replica {
replica_number += 1;
}
// We assume every server in the pool share user/passwords
let auth_passthrough = AuthPassthrough::from_pool_config(pool_config);
if let Some(apt) = &auth_passthrough {
match apt.fetch_hash(&address).await {
Ok(ok) => {
if let Some(ref pool_auth_hash_value) = *(pool_auth_hash.read()) {
if ok != *pool_auth_hash_value {
warn!("Hash is not the same across shards of the same pool, client auth will \
be done using last obtained hash. Server: {}:{}, Database: {}", server.host, server.port, shard.database);
}
}
debug!("Hash obtained for {:?}", address);
{
let mut pool_auth_hash = pool_auth_hash.write();
*pool_auth_hash = Some(ok.clone());
}
},
Err(err) => warn!("Could not obtain password hashes using auth_query config, ignoring. Error: {:?}", err),
}
}
let manager = ServerPool::new(
address.clone(),
user.clone(),
&shard.database,
client_server_map.clone(),
pool_stats.clone(),
pool_auth_hash.clone(),
);
let connect_timeout = match pool_config.connect_timeout {
Some(connect_timeout) => connect_timeout,
None => config.general.connect_timeout,
};
let idle_timeout = match pool_config.idle_timeout {
Some(idle_timeout) => idle_timeout,
None => config.general.idle_timeout,
};
let pool = Pool::builder()
.max_size(user.pool_size)
.connection_timeout(std::time::Duration::from_millis(connect_timeout))
.idle_timeout(Some(std::time::Duration::from_millis(idle_timeout)))
.test_on_check_out(false)
.build(manager)
.await
.unwrap();
pools.push(pool);
servers.push(address);
}
shards.push(pools);
addresses.push(servers);
banlist.push(HashMap::new());
}
assert_eq!(shards.len(), addresses.len());
if let Some(ref _auth_hash) = *(pool_auth_hash.clone().read()) {
info!(
"Auth hash obtained from query_auth for pool {{ name: {}, user: {} }}",
pool_name, user.username
);
}
let pool = ConnectionPool {
databases: shards,
stats: pool_stats,
addresses,
banlist: Arc::new(RwLock::new(banlist)),
config_hash: new_pool_hash_value,
server_info: Arc::new(RwLock::new(BytesMut::new())),
auth_hash: pool_auth_hash,
settings: PoolSettings {
pool_mode: pool_config.pool_mode,
load_balancing_mode: pool_config.load_balancing_mode,
// shards: pool_config.shards.clone(),
shards: shard_ids.len(),
user: user.clone(),
default_role: match pool_config.default_role.as_str() {
"any" => None,
"replica" => Some(Role::Replica),
"primary" => Some(Role::Primary),
_ => unreachable!(),
},
query_parser_enabled: pool_config.query_parser_enabled,
primary_reads_enabled: pool_config.primary_reads_enabled,
sharding_function: pool_config.sharding_function,
automatic_sharding_key: pool_config.automatic_sharding_key.clone(),
healthcheck_delay: config.general.healthcheck_delay,
healthcheck_timeout: config.general.healthcheck_timeout,
ban_time: config.general.ban_time,
sharding_key_regex: pool_config
.sharding_key_regex
.clone()
.map(|regex| Regex::new(regex.as_str()).unwrap()),
shard_id_regex: pool_config
.shard_id_regex
.clone()
.map(|regex| Regex::new(regex.as_str()).unwrap()),
regex_search_limit: pool_config.regex_search_limit.unwrap_or(1000),
auth_query: pool_config.auth_query.clone(),
auth_query_user: pool_config.auth_query_user.clone(),
auth_query_password: pool_config.auth_query_password.clone(),
},
validated: Arc::new(AtomicBool::new(false)),
paused: Arc::new(AtomicBool::new(false)),
paused_waiter: Arc::new(Notify::new()),
}; };
// Connect to the servers to make sure pool configuration is valid secrets.push(None);
// before setting it globally.
// Do this async and somewhere else, we don't have to wait here.
let mut validate_pool = pool.clone();
tokio::task::spawn(async move {
let _ = validate_pool.validate().await;
});
// There is one pool per database/user pair. for secret in secrets {
new_pools.insert(PoolIdentifier::new(pool_name, &user.username), pool); let old_pool_ref = get_pool(pool_name, &user.username, secret.clone());
let identifier = PoolIdentifier::new(pool_name, &user.username, secret.clone());
match old_pool_ref {
Some(pool) => {
// If the pool hasn't changed, get existing reference and insert it into the new_pools.
// We replace all pools at the end, but if the reference is kept, the pool won't get re-created (bb8).
if pool.config_hash == new_pool_hash_value {
info!(
"[pool: {}][user: {}] has not changed",
pool_name, user.username
);
new_pools.insert(identifier.clone(), pool.clone());
continue;
}
}
None => (),
}
info!(
"[pool: {}][user: {}] creating new pool",
pool_name, user.username
);
let mut shards = Vec::new();
let mut addresses = Vec::new();
let mut banlist = Vec::new();
let mut shard_ids = pool_config
.shards
.clone()
.into_keys()
.collect::<Vec<String>>();
let pool_stats = Arc::new(PoolStats::new(identifier, pool_config.clone()));
// Allow the pool to be seen in statistics
pool_stats.register(pool_stats.clone());
// Sort by shard number to ensure consistency.
shard_ids.sort_by_key(|k| k.parse::<i64>().unwrap());
let pool_auth_hash: Arc<RwLock<Option<String>>> = Arc::new(RwLock::new(None));
for shard_idx in &shard_ids {
let shard = &pool_config.shards[shard_idx];
let mut pools = Vec::new();
let mut servers = Vec::new();
let mut replica_number = 0;
// Load Mirror settings
for (address_index, server) in shard.servers.iter().enumerate() {
let mut mirror_addresses = vec![];
if let Some(mirror_settings_vec) = &shard.mirrors {
for (mirror_idx, mirror_settings) in
mirror_settings_vec.iter().enumerate()
{
if mirror_settings.mirroring_target_index != address_index {
continue;
}
mirror_addresses.push(Address {
id: address_id,
database: shard.database.clone(),
host: mirror_settings.host.clone(),
port: mirror_settings.port,
role: server.role,
address_index: mirror_idx,
replica_number,
shard: shard_idx.parse::<usize>().unwrap(),
username: user.username.clone(),
pool_name: pool_name.clone(),
mirrors: vec![],
stats: Arc::new(AddressStats::default()),
});
address_id += 1;
}
}
let address = Address {
id: address_id,
database: shard.database.clone(),
host: server.host.clone(),
port: server.port,
role: server.role,
address_index,
replica_number,
shard: shard_idx.parse::<usize>().unwrap(),
username: user.username.clone(),
pool_name: pool_name.clone(),
mirrors: mirror_addresses,
stats: Arc::new(AddressStats::default()),
};
address_id += 1;
if server.role == Role::Replica {
replica_number += 1;
}
// We assume every server in the pool share user/passwords
let auth_passthrough = AuthPassthrough::from_pool_config(pool_config);
if let Some(apt) = &auth_passthrough {
match apt.fetch_hash(&address).await {
Ok(ok) => {
if let Some(ref pool_auth_hash_value) = *(pool_auth_hash.read()) {
if ok != *pool_auth_hash_value {
warn!("Hash is not the same across shards of the same pool, client auth will \
be done using last obtained hash. Server: {}:{}, Database: {}", server.host, server.port, shard.database);
}
}
debug!("Hash obtained for {:?}", address);
{
let mut pool_auth_hash = pool_auth_hash.write();
*pool_auth_hash = Some(ok.clone());
}
},
Err(err) => warn!("Could not obtain password hashes using auth_query config, ignoring. Error: {:?}", err),
}
}
let manager = ServerPool::new(
address.clone(),
user.clone(),
&shard.database,
client_server_map.clone(),
pool_stats.clone(),
pool_auth_hash.clone(),
);
let connect_timeout = match pool_config.connect_timeout {
Some(connect_timeout) => connect_timeout,
None => config.general.connect_timeout,
};
let idle_timeout = match pool_config.idle_timeout {
Some(idle_timeout) => idle_timeout,
None => config.general.idle_timeout,
};
let pool = Pool::builder()
.max_size(user.pool_size)
.connection_timeout(std::time::Duration::from_millis(
connect_timeout,
))
.idle_timeout(Some(std::time::Duration::from_millis(idle_timeout)))
.test_on_check_out(false)
.build(manager)
.await
.unwrap();
pools.push(pool);
servers.push(address);
}
shards.push(pools);
addresses.push(servers);
banlist.push(HashMap::new());
}
assert_eq!(shards.len(), addresses.len());
if let Some(ref _auth_hash) = *(pool_auth_hash.clone().read()) {
info!(
"Auth hash obtained from query_auth for pool {{ name: {}, user: {} }}",
pool_name, user.username
);
}
let pool = ConnectionPool {
databases: shards,
stats: pool_stats,
addresses,
banlist: Arc::new(RwLock::new(banlist)),
config_hash: new_pool_hash_value,
server_info: Arc::new(RwLock::new(BytesMut::new())),
auth_hash: pool_auth_hash,
settings: PoolSettings {
pool_mode: pool_config.pool_mode,
load_balancing_mode: pool_config.load_balancing_mode,
// shards: pool_config.shards.clone(),
shards: shard_ids.len(),
user: user.clone(),
default_role: match pool_config.default_role.as_str() {
"any" => None,
"replica" => Some(Role::Replica),
"primary" => Some(Role::Primary),
_ => unreachable!(),
},
query_parser_enabled: pool_config.query_parser_enabled,
primary_reads_enabled: pool_config.primary_reads_enabled,
sharding_function: pool_config.sharding_function,
automatic_sharding_key: pool_config.automatic_sharding_key.clone(),
healthcheck_delay: config.general.healthcheck_delay,
healthcheck_timeout: config.general.healthcheck_timeout,
ban_time: config.general.ban_time,
sharding_key_regex: pool_config
.sharding_key_regex
.clone()
.map(|regex| Regex::new(regex.as_str()).unwrap()),
shard_id_regex: pool_config
.shard_id_regex
.clone()
.map(|regex| Regex::new(regex.as_str()).unwrap()),
regex_search_limit: pool_config.regex_search_limit.unwrap_or(1000),
auth_query: pool_config.auth_query.clone(),
auth_query_user: pool_config.auth_query_user.clone(),
auth_query_password: pool_config.auth_query_password.clone(),
},
validated: Arc::new(AtomicBool::new(false)),
paused: Arc::new(AtomicBool::new(false)),
paused_waiter: Arc::new(Notify::new()),
};
// Connect to the servers to make sure pool configuration is valid
// before setting it globally.
// Do this async and somewhere else, we don't have to wait here.
let mut validate_pool = pool.clone();
tokio::task::spawn(async move {
let _ = validate_pool.validate().await;
});
// There is one pool per database/user pair.
new_pools.insert(PoolIdentifier::new(pool_name, &user.username, secret), pool);
}
} }
} }
@@ -924,10 +939,10 @@ impl ManageConnection for ServerPool {
} }
/// Get the connection pool /// Get the connection pool
pub fn get_pool(db: &str, user: &str) -> Option<ConnectionPool> { pub fn get_pool(db: &str, user: &str, secret: Option<String>) -> Option<ConnectionPool> {
(*(*POOLS.load())) let identifier = PoolIdentifier::new(db, user, secret);
.get(&PoolIdentifier::new(db, user))
.cloned() (*(*POOLS.load())).get(&identifier).cloned()
} }
/// Get a pointer to all configured pools. /// Get a pointer to all configured pools.

View File

@@ -9,7 +9,7 @@ use std::sync::atomic::Ordering;
use std::sync::Arc; use std::sync::Arc;
use crate::config::Address; use crate::config::Address;
use crate::pool::get_all_pools; use crate::pool::{get_all_pools, PoolIdentifier};
use crate::stats::{get_pool_stats, get_server_stats, ServerStats}; use crate::stats::{get_pool_stats, get_server_stats, ServerStats};
struct MetricHelpType { struct MetricHelpType {
@@ -233,10 +233,10 @@ impl<Value: fmt::Display> PrometheusMetric<Value> {
Self::from_name(&format!("stats_{}", name), value, labels) Self::from_name(&format!("stats_{}", name), value, labels)
} }
fn from_pool(pool: &(String, String), name: &str, value: u64) -> Option<PrometheusMetric<u64>> { fn from_pool(pool: &PoolIdentifier, name: &str, value: u64) -> Option<PrometheusMetric<u64>> {
let mut labels = HashMap::new(); let mut labels = HashMap::new();
labels.insert("pool", pool.0.clone()); labels.insert("pool", pool.db.clone());
labels.insert("user", pool.1.clone()); labels.insert("user", pool.user.clone());
Self::from_name(&format!("pools_{}", name), value, labels) Self::from_name(&format!("pools_{}", name), value, labels)
} }
@@ -294,7 +294,7 @@ fn push_pool_stats(lines: &mut Vec<String>) {
} else { } else {
warn!( warn!(
"Metric {} not implemented for ({},{})", "Metric {} not implemented for ({},{})",
name, pool.0, pool.1 name, pool.db, pool.user
); );
} }
} }

View File

@@ -755,7 +755,9 @@ impl Server {
Arc::new(RwLock::new(None)), Arc::new(RwLock::new(None)),
) )
.await?; .await?;
debug!("Connected!, sending query.");
debug!("Connected!, sending query: {}", query);
server.send(&simple_query(query)).await?; server.send(&simple_query(query)).await?;
let mut message = server.recv().await?; let mut message = server.recv().await?;
@@ -764,6 +766,8 @@ impl Server {
} }
async fn parse_query_message(message: &mut BytesMut) -> Result<Vec<String>, Error> { async fn parse_query_message(message: &mut BytesMut) -> Result<Vec<String>, Error> {
debug!("Parsing query message");
let mut pair = Vec::<String>::new(); let mut pair = Vec::<String>::new();
match message::backend::Message::parse(message) { match message::backend::Message::parse(message) {
Ok(Some(message::backend::Message::RowDescription(_description))) => {} Ok(Some(message::backend::Message::RowDescription(_description))) => {}
@@ -833,6 +837,9 @@ async fn parse_query_message(message: &mut BytesMut) -> Result<Vec<String>, Erro
} }
}; };
} }
debug!("Got auth hash successfully");
Ok(pair) Ok(pair)
} }

View File

@@ -22,7 +22,7 @@ pub use server::{ServerState, ServerStats};
/// Convenience types for various stats /// Convenience types for various stats
type ClientStatesLookup = HashMap<i32, Arc<ClientStats>>; type ClientStatesLookup = HashMap<i32, Arc<ClientStats>>;
type ServerStatesLookup = HashMap<i32, Arc<ServerStats>>; type ServerStatesLookup = HashMap<i32, Arc<ServerStats>>;
type PoolStatsLookup = HashMap<(String, String), Arc<PoolStats>>; type PoolStatsLookup = HashMap<PoolIdentifier, Arc<PoolStats>>;
/// Stats for individual client connections /// Stats for individual client connections
/// Used in SHOW CLIENTS. /// Used in SHOW CLIENTS.
@@ -83,9 +83,7 @@ impl Reporter {
/// Register a pool with the stats system. /// Register a pool with the stats system.
fn pool_register(&self, identifier: PoolIdentifier, stats: Arc<PoolStats>) { fn pool_register(&self, identifier: PoolIdentifier, stats: Arc<PoolStats>) {
POOL_STATS POOL_STATS.write().insert(identifier, stats);
.write()
.insert((identifier.db, identifier.user), stats);
} }
} }

View File

@@ -102,6 +102,13 @@ impl PoolStats {
self.identifier.user.clone() self.identifier.user.clone()
} }
pub fn redacted_secret(&self) -> String {
match self.identifier.secret {
Some(ref s) => format!("****{}", &s[s.len() - 4..]),
None => "<no secret>".to_string(),
}
}
pub fn pool_mode(&self) -> PoolMode { pub fn pool_mode(&self) -> PoolMode {
self.config.pool_mode self.config.pool_mode
} }

View File

@@ -67,7 +67,7 @@ describe "Auth Query" do
end end
context 'and with cleartext passwords not set' do context 'and with cleartext passwords not set' do
let(:config_user) { { 'username' => 'sharding_user', 'password' => 'sharding_user' } } let(:config_user) { { 'username' => 'sharding_user' } }
it 'it uses obtained passwords' do it 'it uses obtained passwords' do
connection_string = processes.pgcat.connection_string("sharded_db", pg_user['username'], pg_user['password']) connection_string = processes.pgcat.connection_string("sharded_db", pg_user['username'], pg_user['password'])
@@ -76,7 +76,7 @@ describe "Auth Query" do
end end
it 'allows passwords to be changed without closing existing connections' do it 'allows passwords to be changed without closing existing connections' do
pgconn = PG.connect(processes.pgcat.connection_string("sharded_db", pg_user['username'])) pgconn = PG.connect(processes.pgcat.connection_string("sharded_db", pg_user['username'], pg_user['password']))
expect(pgconn.exec("SELECT 1 + 2")).not_to be_nil expect(pgconn.exec("SELECT 1 + 2")).not_to be_nil
Helpers::AuthQuery.exec_in_instances(query: "ALTER USER #{pg_user['username']} WITH ENCRYPTED PASSWORD 'secret2';") Helpers::AuthQuery.exec_in_instances(query: "ALTER USER #{pg_user['username']} WITH ENCRYPTED PASSWORD 'secret2';")
expect(pgconn.exec("SELECT 1 + 4")).not_to be_nil expect(pgconn.exec("SELECT 1 + 4")).not_to be_nil
@@ -84,7 +84,7 @@ describe "Auth Query" do
end end
it 'allows passwords to be changed and that new password is needed when reconnecting' do it 'allows passwords to be changed and that new password is needed when reconnecting' do
pgconn = PG.connect(processes.pgcat.connection_string("sharded_db", pg_user['username'])) pgconn = PG.connect(processes.pgcat.connection_string("sharded_db", pg_user['username'], pg_user['password']))
expect(pgconn.exec("SELECT 1 + 2")).not_to be_nil expect(pgconn.exec("SELECT 1 + 2")).not_to be_nil
Helpers::AuthQuery.exec_in_instances(query: "ALTER USER #{pg_user['username']} WITH ENCRYPTED PASSWORD 'secret2';") Helpers::AuthQuery.exec_in_instances(query: "ALTER USER #{pg_user['username']} WITH ENCRYPTED PASSWORD 'secret2';")
newconn = PG.connect(processes.pgcat.connection_string("sharded_db", pg_user['username'], 'secret2')) newconn = PG.connect(processes.pgcat.connection_string("sharded_db", pg_user['username'], 'secret2'))

39
tests/ruby/auth_spec.rb Normal file
View File

@@ -0,0 +1,39 @@
# frozen_string_literal: true
require_relative 'spec_helper'
describe "Authentication" do
describe "multiple secrets configured" do
let(:secrets) { ["one_secret", "two_secret"] }
let(:processes) { Helpers::Pgcat.three_shard_setup("sharded_db", 5, pool_mode="transaction", lb_mode="random", log_level="info", secrets=["one_secret", "two_secret"]) }
after do
processes.all_databases.map(&:reset)
processes.pgcat.shutdown
end
it "can connect using all secrets and postgres password" do
secrets.push("sharding_user").each do |secret|
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user", password=secret))
conn.exec("SELECT current_user")
end
end
end
describe "no secrets configured" do
let(:secrets) { [] }
let(:processes) { Helpers::Pgcat.three_shard_setup("sharded_db", 5, pool_mode="transaction", lb_mode="random", log_level="info") }
after do
processes.all_databases.map(&:reset)
processes.pgcat.shutdown
end
it "can connect using only the password" do
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
conn.exec("SELECT current_user")
expect { PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user", password="secret_one")) }.to raise_error PG::ConnectionBad
end
end
end

View File

@@ -12,14 +12,18 @@ end
module Helpers module Helpers
module Pgcat module Pgcat
def self.three_shard_setup(pool_name, pool_size, pool_mode="transaction", lb_mode="random", log_level="info") def self.three_shard_setup(pool_name, pool_size, pool_mode="transaction", lb_mode="random", log_level="info", secrets=nil)
user = { user = {
"password" => "sharding_user", "password" => "sharding_user",
"pool_size" => pool_size, "pool_size" => pool_size,
"statement_timeout" => 0, "statement_timeout" => 0,
"username" => "sharding_user" "username" => "sharding_user",
} }
if !secrets.nil?
user["secrets"] = secrets
end
pgcat = PgcatProcess.new(log_level) pgcat = PgcatProcess.new(log_level)
primary0 = PgInstance.new(5432, user["username"], user["password"], "shard0") primary0 = PgInstance.new(5432, user["username"], user["password"], "shard0")
primary1 = PgInstance.new(7432, user["username"], user["password"], "shard1") primary1 = PgInstance.new(7432, user["username"], user["password"], "shard1")
@@ -27,7 +31,7 @@ module Helpers
pgcat_cfg = pgcat.current_config pgcat_cfg = pgcat.current_config
pgcat_cfg["pools"] = { pgcat_cfg["pools"] = {
"#{pool_name}" => { "#{pool_name}" => {
"default_role" => "any", "default_role" => "any",
"pool_mode" => pool_mode, "pool_mode" => pool_mode,
"load_balancing_mode" => lb_mode, "load_balancing_mode" => lb_mode,
@@ -41,8 +45,14 @@ module Helpers
"2" => { "database" => "shard2", "servers" => [["localhost", primary2.port.to_s, "primary"]] }, "2" => { "database" => "shard2", "servers" => [["localhost", primary2.port.to_s, "primary"]] },
}, },
"users" => { "0" => user } "users" => { "0" => user }
} },
} }
if !secrets.nil?
pgcat_cfg["general"]["tls_certificate"] = "../../.circleci/server.cert"
pgcat_cfg["general"]["tls_private_key"] = "../../.circleci/server.key"
end
pgcat.update_config(pgcat_cfg) pgcat.update_config(pgcat_cfg)
pgcat.start pgcat.start

View File

@@ -78,7 +78,6 @@ class PgcatProcess
10.times do 10.times do
Process.kill 0, @pid Process.kill 0, @pid
PG::connect(connection_string || example_connection_string).close PG::connect(connection_string || example_connection_string).close
return self return self
rescue Errno::ESRCH rescue Errno::ESRCH
raise StandardError, "Process #{@pid} died. #{logs}" raise StandardError, "Process #{@pid} died. #{logs}"
@@ -112,10 +111,13 @@ class PgcatProcess
"postgresql://#{username}:#{password}@0.0.0.0:#{@port}/pgcat" "postgresql://#{username}:#{password}@0.0.0.0:#{@port}/pgcat"
end end
def connection_string(pool_name, username, password = nil) def connection_string(pool_name, username, password=nil)
cfg = current_config cfg = current_config
user_idx, user_obj = cfg["pools"][pool_name]["users"].detect { |k, user| user["username"] == username } user_idx, user_obj = cfg["pools"][pool_name]["users"].detect { |k, user| user["username"] == username }
"postgresql://#{username}:#{password || user_obj["password"]}@0.0.0.0:#{@port}/#{pool_name}"
password = if password.nil? then user_obj["password"] else password end
"postgresql://#{username}:#{password}@0.0.0.0:#{@port}/#{pool_name}"
end end
def example_connection_string def example_connection_string