From 6f768a84ce8acfb3a7e861fa0dd92d31b3c6afce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jose=20Fern=C3=A1ndez?= Date: Thu, 30 Mar 2023 22:29:23 +0200 Subject: [PATCH] Auth passthrough (auth_query) (#266) * Add a new exec_simple_query method This adds a new `exec_simple_query` method so we can make 'out of band' queries to servers that don't interfere with pools at all. In order to reuse startup code for making these simple queries, we need to set the stats (`Reporter`) optional, so using these simple queries wont interfere with stats. * Add auth passthough (auth_query) Adds a feature that allows setting auth passthrough for md5 auth. It adds 3 new (general and pool) config parameters: - `auth_query`: An string containing a query that will be executed on boot to obtain the hash of a given user. This query have to use a placeholder `$1`, so pgcat can replace it with the user its trying to fetch the hash from. - `auth_query_user`: The user to use for connecting to the server and executing the auth_query. - `auth_query_password`: The password to use for connecting to the server and executing the auth_query. The configuration can be done either on the general config (so pools share them) or in a per-pool basis. The behavior is, at boot time, when validating server connections, a hash is fetched per server and stored in the pool. When new server connections are created, and no cleartext password is specified, the obtained hash is used for creating them, if the hash could not be obtained for whatever reason, it retries it. When client authentication is tried, it uses cleartext passwords if specified, it not, it checks whether we have query_auth set up, if so, it tries to use the obtained hash for making client auth. If there is no hash (we could not obtain one when validating the connection), a new fetch is tried. Once we have a hash, we authenticate using it against whathever the client has sent us, if there is a failure we refetch the hash and retry auth (so password changes can be done). The idea with this 'retrial' mechanism is to make it fault tolerant, so if for whatever reason hash could not be obtained during connection validation, or the password has change, we can still connect later. * Add documentation for Auth passthrough --- .circleci/config.yml | 8 + .circleci/run_tests.sh | 1 + .rustfmt.toml | 2 + CONFIG.md | 56 +++++- Cargo.lock | 44 ++++- Cargo.toml | 3 +- README.md | 1 + dev/docker-compose.yaml | 8 + src/auth_passthrough.rs | 107 ++++++++++++ src/client.rs | 82 ++++++++- src/config.rs | 89 +++++++++- src/errors.rs | 2 + src/lib.rs | 1 + src/main.rs | 1 + src/messages.rs | 22 ++- src/mirrors.rs | 2 + src/pool.rs | 53 +++++- src/query_router.rs | 6 + src/server.rs | 150 ++++++++++++++++- tests/docker/docker-compose.yml | 9 + tests/ruby/auth_query_spec.rb | 215 ++++++++++++++++++++++++ tests/ruby/helpers/auth_query_helper.rb | 173 +++++++++++++++++++ tests/ruby/helpers/pgcat_helper.rb | 7 + tests/ruby/helpers/pgcat_process.rb | 15 +- 24 files changed, 1026 insertions(+), 31 deletions(-) create mode 100644 .rustfmt.toml create mode 100644 src/auth_passthrough.rs create mode 100644 tests/ruby/auth_query_spec.rb create mode 100644 tests/ruby/helpers/auth_query_helper.rb diff --git a/.circleci/config.yml b/.circleci/config.yml index 5e2d114..c7f5c9f 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -46,6 +46,14 @@ jobs: POSTGRES_PASSWORD: postgres POSTGRES_INITDB_ARGS: --auth-local=scram-sha-256 --auth-host=scram-sha-256 --auth=scram-sha-256 + - image: postgres:14 + command: ["postgres", "-p", "10432", "-c", "shared_preload_libraries=pg_stat_statements"] + environment: + POSTGRES_USER: postgres + POSTGRES_DB: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_INITDB_ARGS: --auth-local=md5 --auth-host=md5 --auth=md5 + # Add steps to the job # See: https://circleci.com/docs/2.0/configuration-reference/#steps steps: diff --git a/.circleci/run_tests.sh b/.circleci/run_tests.sh index a5cfab0..4ba497c 100644 --- a/.circleci/run_tests.sh +++ b/.circleci/run_tests.sh @@ -19,6 +19,7 @@ PGPASSWORD=postgres psql -e -h 127.0.0.1 -p 5432 -U postgres -f tests/sharding/q PGPASSWORD=postgres psql -e -h 127.0.0.1 -p 7432 -U postgres -f tests/sharding/query_routing_setup.sql PGPASSWORD=postgres psql -e -h 127.0.0.1 -p 8432 -U postgres -f tests/sharding/query_routing_setup.sql PGPASSWORD=postgres psql -e -h 127.0.0.1 -p 9432 -U postgres -f tests/sharding/query_routing_setup.sql +PGPASSWORD=postgres psql -e -h 127.0.0.1 -p 10432 -U postgres -f tests/sharding/query_routing_setup.sql PGPASSWORD=sharding_user pgbench -h 127.0.0.1 -U sharding_user shard0 -i PGPASSWORD=sharding_user pgbench -h 127.0.0.1 -U sharding_user shard1 -i diff --git a/.rustfmt.toml b/.rustfmt.toml new file mode 100644 index 0000000..17f3321 --- /dev/null +++ b/.rustfmt.toml @@ -0,0 +1,2 @@ +edition = "2021" +hard_tabs = false diff --git a/CONFIG.md b/CONFIG.md index bcd6f09..3cec253 100644 --- a/CONFIG.md +++ b/CONFIG.md @@ -175,11 +175,41 @@ Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DAT ### admin_password ``` path: general.admin_password -default: "admin_pass" +default: ``` Password to access the virtual administrative database +### auth_query (experimental) +``` +path: general.auth_query +default: +``` + +Query to be sent to servers to obtain the hash used for md5 authentication. The connection will be +established using the database configured in the pool. This parameter is inherited by every pool +and can be redefined in pool configuration. + +### auth_query_user (experimental) +``` +path: general.auth_query_user +default: +``` + +User to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query +specified in `auth_query_user`. The connection will be established using the database configured in the pool. +This parameter is inherited by every pool and can be redefined in pool configuration. + +### auth_query_password (experimental) +``` +path: general.auth_query_password +default: +``` + +Password to be used for connecting to servers to obtain the hash used for md5 authentication by sending the query +specified in `auth_query_user`. The connection will be established using the database configured in the pool. +This parameter is inherited by every pool and can be redefined in pool configuration. + ## `pools.` Section ### pool_mode @@ -281,6 +311,30 @@ default: 3000 Connect timeout can be overwritten in the pool +### auth_query (experimental) +``` +path: general.auth_query +default: +``` + +Auth query can be overwritten in the pool + +### auth_query_user (experimental) +``` +path: general.auth_query_user +default: +``` + +Auth query user can be overwritten in the pool + +### auth_query_password (experimental) +``` +path: general.auth_query_password +default: +``` + +Auth query password can be overwritten in the pool + ## `pools..users.` Section ### username diff --git a/Cargo.lock b/Cargo.lock index a4e5346..ad7bb8e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -45,7 +45,7 @@ checksum = "6227a8d6fdb862bcb100c4314d0d9579e5cd73fa6df31a2e6f6e1acd3c5f1207" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 1.0.109", ] [[package]] @@ -54,6 +54,12 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" +[[package]] +name = "base64" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" + [[package]] name = "base64" version = "0.21.0" @@ -94,6 +100,12 @@ version = "3.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0d261e256854913907f67ed06efbc3338dfe6179796deefc1ff763fc1aee5535" +[[package]] +name = "byteorder" +version = "1.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" + [[package]] name = "bytes" version = "1.4.0" @@ -257,6 +269,12 @@ version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "de853764b47027c2e862a995c34978ffa63c1501f2e15f987ba11bd4f9bba193" +[[package]] +name = "fallible-iterator" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7" + [[package]] name = "fnv" version = "1.0.7" @@ -732,12 +750,13 @@ dependencies = [ "arc-swap", "async-trait", "atomic_enum", - "base64", + "base64 0.21.0", "bb8", "bytes", "chrono", "env_logger", "exitcode", + "fallible-iterator", "futures", "hmac", "hyper", @@ -749,6 +768,7 @@ dependencies = [ "once_cell", "parking_lot", "phf", + "postgres-protocol", "rand", "regex", "rustls-pemfile", @@ -818,6 +838,24 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "postgres-protocol" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "878c6cbf956e03af9aa8204b407b9cbf47c072164800aa918c516cd4b056c50c" +dependencies = [ + "base64 0.13.1", + "byteorder", + "bytes", + "fallible-iterator", + "hmac", + "md-5", + "memchr", + "rand", + "sha2", + "stringprep", +] + [[package]] name = "ppv-lite86" version = "0.2.17" @@ -945,7 +983,7 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d194b56d58803a43635bdc398cd17e383d6f71f9182b9a192c127ca42494a59b" dependencies = [ - "base64", + "base64 0.21.0", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 3f0cbec..89cfe64 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,6 @@ version = "1.0.0" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - [dependencies] tokio = { version = "1", features = ["full"] } bytes = "1" @@ -38,6 +37,8 @@ futures = "0.3" socket2 = { version = "0.4.7", features = ["all"] } nix = "0.26.2" atomic_enum = "0.2.0" +postgres-protocol = "0.6.4" +fallible-iterator = "0.2" [target.'cfg(not(target_env = "msvc"))'.dependencies] jemallocator = "0.5.0" diff --git a/README.md b/README.md index 4d6f599..63b5ab1 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,7 @@ PostgreSQL pooler and proxy (like PgBouncer) with support for sharding, load bal | Sharding using comments parsing/Regex | **Experimental** | Clients can include shard information (sharding key, shard ID) in the query comments. | | Automatic sharding | **Experimental** | PgCat can parse queries, detect sharding keys automatically, and route queries to the correct shard. | | Mirroring | **Experimental** | Mirror queries between multiple databases in order to test servers with realistic production traffic. | +| Auth passthrough | **Experimental** | MD5 password authentication can be configured to use an `auth_query` so no cleartext passwords are needed in the config file. | ## Status diff --git a/dev/docker-compose.yaml b/dev/docker-compose.yaml index 15621e8..71704bc 100644 --- a/dev/docker-compose.yaml +++ b/dev/docker-compose.yaml @@ -58,6 +58,13 @@ services: POSTGRES_INITDB_ARGS: --auth-local=scram-sha-256 --auth-host=scram-sha-256 --auth=scram-sha-256 PGPORT: 9432 command: ["postgres", "-p", "9432", "-c", "shared_preload_libraries=pg_stat_statements", "-c", "pg_stat_statements.track=all", "-c", "pg_stat_statements.max=100000"] + pg5: + <<: *common-definition-pg + environment: + <<: *common-env-pg + POSTGRES_INITDB_ARGS: --auth-local=md5 --auth-host=md5 --auth=md5 + PGPORT: 10432 + command: ["postgres", "-p", "5432", "-c", "shared_preload_libraries=pg_stat_statements", "-c", "pg_stat_statements.track=all", "-c", "pg_stat_statements.max=100000"] toxiproxy: build: . @@ -71,6 +78,7 @@ services: - pg2 - pg3 - pg4 + - pg5 pgcat-shell: stdin_open: true diff --git a/src/auth_passthrough.rs b/src/auth_passthrough.rs new file mode 100644 index 0000000..b9f0e97 --- /dev/null +++ b/src/auth_passthrough.rs @@ -0,0 +1,107 @@ +use crate::errors::Error; +use crate::server::Server; +use log::debug; + +#[derive(Clone, Debug)] +pub struct AuthPassthrough { + password: String, + query: String, + user: String, +} + +impl AuthPassthrough { + /// Initializes an AuthPassthrough. + pub fn new(query: &str, user: &str, password: &str) -> Self { + AuthPassthrough { + password: password.to_string(), + query: query.to_string(), + user: user.to_string(), + } + } + + /// Returns an AuthPassthrough given the pool configuration. + /// If any of required values is not set, None is returned. + pub fn from_pool_config(pool_config: &crate::config::Pool) -> Option { + if pool_config.is_auth_query_configured() { + return Some(AuthPassthrough::new( + pool_config.auth_query.as_ref().unwrap(), + pool_config.auth_query_user.as_ref().unwrap(), + pool_config.auth_query_password.as_ref().unwrap(), + )); + } + + None + } + + /// Returns an AuthPassthrough given the pool settings. + /// If any of required values is not set, None is returned. + pub fn from_pool_settings(pool_settings: &crate::pool::PoolSettings) -> Option { + let pool_config = crate::config::Pool { + auth_query: pool_settings.auth_query.clone(), + auth_query_password: pool_settings.auth_query_password.clone(), + auth_query_user: pool_settings.auth_query_user.clone(), + ..Default::default() + }; + + AuthPassthrough::from_pool_config(&pool_config) + } + + /// Connects to server and executes auth_query for the specified address. + /// If the response is a row with two columns containing the username set in the address. + /// and its MD5 hash, the MD5 hash returned. + /// + /// Note that the query is executed, changing $1 with the name of the user + /// this is so we only hold in memory (and transfer) the least amount of 'sensitive' data. + /// Also, it is compatible with pgbouncer. + /// + /// # Arguments + /// + /// * `address` - An Address of the server we want to connect to. The username for the hash will be obtained from this value. + /// + /// # Examples + /// + /// ``` + /// use pgcat::auth_passthrough::AuthPassthrough; + /// use pgcat::config::Address; + /// let auth_passthrough = AuthPassthrough::new("SELECT * FROM public.user_lookup('$1');", "postgres", "postgres"); + /// auth_passthrough.fetch_hash(&Address::default()); + /// ``` + /// + pub async fn fetch_hash(&self, address: &crate::config::Address) -> Result { + let auth_user = crate::config::User { + username: self.user.clone(), + password: Some(self.password.clone()), + pool_size: 1, + statement_timeout: 0, + }; + + let user = &address.username; + + debug!("Connecting to server to obtain auth hashes."); + let auth_query = self.query.replace("$1", user); + match Server::exec_simple_query(address, &auth_user, &auth_query).await { + Ok(password_data) => { + if password_data.len() == 2 && password_data.first().unwrap() == user { + if let Some(stripped_hash) = password_data.last().unwrap().to_string().strip_prefix("md5") { + Ok(stripped_hash.to_string()) + } + else { + Err(Error::AuthPassthroughError( + "Obtained hash from auth_query does not seem to be in md5 format.".to_string(), + )) + } + } else { + Err(Error::AuthPassthroughError( + "Data obtained from query does not follow the scheme 'user','hash'." + .to_string(), + )) + } + } + Err(err) => { + Err(Error::AuthPassthroughError( + format!("Error trying to obtain password from auth_query, ignoring hash for user '{}'. Error: {:?}", + user, err))) + } + } + } +} diff --git a/src/client.rs b/src/client.rs index f9f4e01..d75c069 100644 --- a/src/client.rs +++ b/src/client.rs @@ -12,9 +12,9 @@ use tokio::sync::broadcast::Receiver; use tokio::sync::mpsc::Sender; use crate::admin::{generate_server_info_for_admin, handle_admin}; +use crate::auth_passthrough::AuthPassthrough; use crate::config::{get_config, get_idle_client_in_transaction_timeout, Address, PoolMode}; use crate::constants::*; - use crate::messages::*; use crate::pool::{get_pool, ClientServerMap, ConnectionPool}; use crate::query_router::{Command, QueryRouter}; @@ -377,6 +377,20 @@ pub async fn startup_tls( } } +async fn refetch_auth_hash(pool: &ConnectionPool) -> Result { + let address = pool.address(0, 0); + if let Some(apt) = AuthPassthrough::from_pool_settings(&pool.settings) { + let hash = apt.fetch_hash(address).await?; + + return Ok(hash); + } + + Err(Error::ClientError(format!( + "Could not obtain hash for {{ username: {:?}, database: {:?} }}. Auth passthrough not enabled.", + address.username, address.database + ))) +} + impl Client where S: tokio::io::AsyncRead + std::marker::Unpin, @@ -509,14 +523,68 @@ where } }; - // Compare server and client hashes. - let password_hash = md5_hash_password(username, &pool.settings.user.password, &salt); + // Obtain the hash to compare, we give preference to that written in cleartext in config + // if there is nothing set in cleartext and auth passthrough (auth_query) is configured, we use the hash obtained + // when the pool was created. If there is no hash there, we try to fetch it one more time. + let password_hash = if let Some(password) = &pool.settings.user.password { + Some(md5_hash_password(username, password, &salt)) + } else { + if !get_config().is_auth_query_configured() { + return Err(Error::ClientError(format!("Client auth not possible, no cleartext password set for username: {:?} in config and auth passthrough (query_auth) is not set up.", username))); + } - if password_hash != password_response { - warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name); - wrong_password(&mut write, username).await?; + let mut hash = (*pool.auth_hash.read()).clone(); - return Err(Error::ClientError(format!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))); + if hash.is_none() { + warn!("Query auth configured but no hash password found for pool {}. Will try to refetch it.", pool_name); + match refetch_auth_hash(&pool).await { + Ok(fetched_hash) => { + warn!("Password for {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}, obtained. Updating.", username, pool_name, application_name); + { + let mut pool_auth_hash = pool.auth_hash.write(); + *pool_auth_hash = Some(fetched_hash.clone()); + } + + hash = Some(fetched_hash); + } + Err(err) => { + return Err( + Error::ClientError( + format!("No cleartext password set, and no auth passthrough could not obtain the hash from server for {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}, the error was: {:?}", + username, + pool_name, + application_name, + err) + ) + ); + } + } + }; + + Some(md5_hash_second_pass(&hash.unwrap(), &salt)) + }; + + // Once we have the resulting hash, we compare with what the client gave us. + // If they do not match and auth query is set up, we try to refetch the hash one more time + // to see if the password has changed since the pool was created. + // + // @TODO: we could end up fetching again the same password twice (see above). + if password_hash.unwrap() != password_response { + warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}, will try to refetch it.", username, pool_name, application_name); + let fetched_hash = refetch_auth_hash(&pool).await?; + let new_password_hash = md5_hash_second_pass(&fetched_hash, &salt); + + // Ok password changed in server an auth is possible. + if new_password_hash == password_response { + warn!("Password for {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}, changed in server. Updating.", username, pool_name, application_name); + { + let mut pool_auth_hash = pool.auth_hash.write(); + *pool_auth_hash = Some(fetched_hash); + } + } else { + wrong_password(&mut write, username).await?; + return Err(Error::ClientError(format!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", username, pool_name, application_name))); + } } let transaction_mode = pool.settings.pool_mode == PoolMode::Transaction; diff --git a/src/config.rs b/src/config.rs index 644532a..6545457 100644 --- a/src/config.rs +++ b/src/config.rs @@ -177,7 +177,7 @@ impl Address { #[derive(Clone, PartialEq, Hash, Eq, Serialize, Deserialize, Debug)] pub struct User { pub username: String, - pub password: String, + pub password: Option, pub pool_size: u32, #[serde(default)] // 0 pub statement_timeout: u64, @@ -187,7 +187,7 @@ impl Default for User { fn default() -> User { User { username: String::from("postgres"), - password: String::new(), + password: None, pool_size: 15, statement_timeout: 0, } @@ -250,6 +250,10 @@ pub struct General { pub tls_private_key: Option, pub admin_username: String, pub admin_password: String, + + pub auth_query: Option, + pub auth_query_user: Option, + pub auth_query_password: Option, } impl General { @@ -334,6 +338,9 @@ impl Default for General { tls_private_key: None, admin_username: String::from("admin"), admin_password: String::from("admin"), + auth_query: None, + auth_query_user: None, + auth_query_password: None, } } } @@ -406,6 +413,10 @@ pub struct Pool { pub shard_id_regex: Option, pub regex_search_limit: Option, + pub auth_query: Option, + pub auth_query_user: Option, + pub auth_query_password: Option, + pub shards: BTreeMap, pub users: BTreeMap, // Note, don't put simple fields below these configs. There's a compatability issue with TOML that makes it @@ -420,6 +431,12 @@ impl Pool { s.finish() } + pub fn is_auth_query_configured(&self) -> bool { + self.auth_query_password.is_some() + && self.auth_query_user.is_some() + && self.auth_query_password.is_some() + } + pub fn default_pool_mode() -> PoolMode { PoolMode::Transaction } @@ -512,6 +529,9 @@ impl Default for Pool { sharding_key_regex: None, shard_id_regex: None, regex_search_limit: Some(1000), + auth_query: None, + auth_query_user: None, + auth_query_password: None, } } } @@ -612,9 +632,31 @@ pub struct Config { } impl Config { + pub fn is_auth_query_configured(&self) -> bool { + self.pools + .iter() + .any(|(_name, pool)| pool.is_auth_query_configured()) + } + pub fn default_path() -> String { String::from("pgcat.toml") } + + pub fn fill_up_auth_query_config(&mut self) { + for (_name, pool) in self.pools.iter_mut() { + if pool.auth_query.is_none() { + pool.auth_query = self.general.auth_query.clone(); + } + + if pool.auth_query_user.is_none() { + pool.auth_query_user = self.general.auth_query_user.clone(); + } + + if pool.auth_query_password.is_none() { + pool.auth_query_password = self.general.auth_query_password.clone(); + } + } + } } impl Default for Config { @@ -832,6 +874,35 @@ impl Config { } pub fn validate(&mut self) -> Result<(), Error> { + // Validation for auth_query feature + if self.general.auth_query.is_some() + && (self.general.auth_query_user.is_none() + || self.general.auth_query_password.is_none()) + { + error!("If auth_query is specified, you need to provide a value for `auth_query_user`, `auth_query_password`"); + return Err(Error::BadConfig); + } + + for (name, pool) in self.pools.iter() { + if pool.auth_query.is_some() + && (pool.auth_query_user.is_none() || pool.auth_query_password.is_none()) + { + error!("Error in pool {{ {} }}. If auth_query is specified, you need to provide a value for `auth_query_user`, `auth_query_password`", name); + return Err(Error::BadConfig); + } + + for (_name, user_data) in pool.users.iter() { + if (pool.auth_query.is_none() + || pool.auth_query_password.is_none() + || pool.auth_query_user.is_none()) + && user_data.password.is_none() + { + error!("Error in pool {{ {} }}. You have to specify a user password for every pool if auth_query is not specified", name); + return Err(Error::BadConfig); + } + } + } + // Validate TLS! match self.general.tls_certificate.clone() { Some(tls_certificate) => { @@ -911,6 +982,7 @@ pub async fn parse(path: &str) -> Result<(), Error> { } }; + config.fill_up_auth_query_config(); config.validate()?; config.path = path.to_string(); @@ -980,7 +1052,10 @@ mod test { "sharding_user" ); assert_eq!( - get_config().pools["sharded_db"].users["1"].password, + get_config().pools["sharded_db"].users["1"] + .password + .as_ref() + .unwrap(), "other_user" ); assert_eq!(get_config().pools["sharded_db"].users["1"].pool_size, 21); @@ -1005,10 +1080,16 @@ mod test { "simple_user" ); assert_eq!( - get_config().pools["simple_db"].users["0"].password, + get_config().pools["simple_db"].users["0"] + .password + .as_ref() + .unwrap(), "simple_user" ); assert_eq!(get_config().pools["simple_db"].users["0"].pool_size, 5); + assert_eq!(get_config().general.auth_query, None); + assert_eq!(get_config().general.auth_query_user, None); + assert_eq!(get_config().general.auth_query_password, None); } #[tokio::test] diff --git a/src/errors.rs b/src/errors.rs index 310243c..58fc088 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -15,4 +15,6 @@ pub enum Error { StatementTimeout, ShuttingDown, ParseBytesError(String), + AuthError(String), + AuthPassthroughError(String), } diff --git a/src/lib.rs b/src/lib.rs index 67aa9cb..2645cd4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ +pub mod auth_passthrough; pub mod config; pub mod constants; pub mod errors; diff --git a/src/main.rs b/src/main.rs index a59da21..4c8987f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -61,6 +61,7 @@ use std::sync::Arc; use tokio::sync::broadcast; mod admin; +mod auth_passthrough; mod client; mod config; mod constants; diff --git a/src/messages.rs b/src/messages.rs index c9ace4e..61c36c6 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -213,7 +213,13 @@ pub fn md5_hash_password(user: &str, password: &str, salt: &[u8]) -> Vec { let output = md5.finalize_reset(); // Second pass - md5.update(format!("{:x}", output)); + md5_hash_second_pass(&(format!("{:x}", output)), salt) +} + +pub fn md5_hash_second_pass(hash: &str, salt: &[u8]) -> Vec { + let mut md5 = Md5::new(); + // Second pass + md5.update(hash); md5.update(salt); let mut password = format!("md5{:x}", md5.finalize()) @@ -247,6 +253,20 @@ where write_all(stream, message).await } +pub async fn md5_password_with_hash(stream: &mut S, hash: &str, salt: &[u8]) -> Result<(), Error> +where + S: tokio::io::AsyncWrite + std::marker::Unpin, +{ + let password = md5_hash_second_pass(hash, salt); + let mut message = BytesMut::with_capacity(password.len() as usize + 5); + + message.put_u8(b'p'); + message.put_i32(password.len() as i32 + 4); + message.put_slice(&password[..]); + + write_all(stream, message).await +} + /// Implements a response to our custom `SET SHARDING KEY` /// and `SET SERVER ROLE` commands. /// This tells the client we're ready for the next query. diff --git a/src/mirrors.rs b/src/mirrors.rs index 128fe22..17f91d4 100644 --- a/src/mirrors.rs +++ b/src/mirrors.rs @@ -4,6 +4,7 @@ use std::sync::Arc; /// Packets arrive to us through a channel from the main client and we send them to the server. use bb8::Pool; use bytes::{Bytes, BytesMut}; +use parking_lot::RwLock; use crate::config::{get_config, Address, Role, User}; use crate::pool::{ClientServerMap, PoolIdentifier, ServerPool}; @@ -41,6 +42,7 @@ impl MirroredClient { self.database.as_str(), ClientServerMap::default(), Arc::new(PoolStats::new(identifier, cfg.clone())), + Arc::new(RwLock::new(None)), ); Pool::builder() diff --git a/src/pool.rs b/src/pool.rs index f6f9118..e1ab7cb 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -20,6 +20,7 @@ use tokio::sync::Notify; use crate::config::{get_config, Address, General, LoadBalancingMode, PoolMode, Role, User}; use crate::errors::Error; +use crate::auth_passthrough::AuthPassthrough; use crate::server::Server; use crate::sharding::ShardingFunction; use crate::stats::{AddressStats, ClientStats, PoolStats, ServerStats}; @@ -123,6 +124,11 @@ pub struct PoolSettings { // Limit how much of each query is searched for a potential shard regex match pub regex_search_limit: usize, + + // Auth query parameters + pub auth_query: Option, + pub auth_query_user: Option, + pub auth_query_password: Option, } impl Default for PoolSettings { @@ -143,6 +149,9 @@ impl Default for PoolSettings { sharding_key_regex: None, shard_id_regex: None, regex_search_limit: 1000, + auth_query: None, + auth_query_user: None, + auth_query_password: None, } } } @@ -183,6 +192,9 @@ pub struct ConnectionPool { paused_waiter: Arc, pub stats: Arc, + + /// AuthInfo + pub auth_hash: Arc>>, } impl ConnectionPool { @@ -237,6 +249,7 @@ impl ConnectionPool { // Sort by shard number to ensure consistency. shard_ids.sort_by_key(|k| k.parse::().unwrap()); + let pool_auth_hash: Arc>> = Arc::new(RwLock::new(None)); for shard_idx in &shard_ids { let shard = &pool_config.shards[shard_idx]; @@ -293,12 +306,35 @@ impl ConnectionPool { replica_number += 1; } + // We assume every server in the pool share user/passwords + let auth_passthrough = AuthPassthrough::from_pool_config(pool_config); + + if let Some(apt) = &auth_passthrough { + match apt.fetch_hash(&address).await { + Ok(ok) => { + if let Some(ref pool_auth_hash_value) = *(pool_auth_hash.read()) { + if ok != *pool_auth_hash_value { + warn!("Hash is not the same across shards of the same pool, client auth will \ + be done using last obtained hash. Server: {}:{}, Database: {}", server.host, server.port, shard.database); + } + } + debug!("Hash obtained for {:?}", address); + { + let mut pool_auth_hash = pool_auth_hash.write(); + *pool_auth_hash = Some(ok.clone()); + } + }, + Err(err) => warn!("Could not obtain password hashes using auth_query config, ignoring. Error: {:?}", err), + } + } + let manager = ServerPool::new( address.clone(), user.clone(), &shard.database, client_server_map.clone(), pool_stats.clone(), + pool_auth_hash.clone(), ); let connect_timeout = match pool_config.connect_timeout { @@ -330,6 +366,12 @@ impl ConnectionPool { } assert_eq!(shards.len(), addresses.len()); + if let Some(ref _auth_hash) = *(pool_auth_hash.clone().read()) { + info!( + "Auth hash obtained from query_auth for pool {{ name: {}, user: {} }}", + pool_name, user.username + ); + } let pool = ConnectionPool { databases: shards, @@ -338,6 +380,7 @@ impl ConnectionPool { banlist: Arc::new(RwLock::new(banlist)), config_hash: new_pool_hash_value, server_info: Arc::new(RwLock::new(BytesMut::new())), + auth_hash: pool_auth_hash, settings: PoolSettings { pool_mode: pool_config.pool_mode, load_balancing_mode: pool_config.load_balancing_mode, @@ -366,6 +409,9 @@ impl ConnectionPool { .clone() .map(|regex| Regex::new(regex.as_str()).unwrap()), regex_search_limit: pool_config.regex_search_limit.unwrap_or(1000), + auth_query: pool_config.auth_query.clone(), + auth_query_user: pool_config.auth_query_user.clone(), + auth_query_password: pool_config.auth_query_password.clone(), }, validated: Arc::new(AtomicBool::new(false)), paused: Arc::new(AtomicBool::new(false)), @@ -389,7 +435,8 @@ impl ConnectionPool { Ok(()) } - /// Connect to all shards and grab server information. + /// Connect to all shards, grab server information, and possibly + /// passwords to use in client auth. /// Return server information we will pass to the clients /// when they connect. /// This also warms up the pool for clients that connect when @@ -803,6 +850,7 @@ pub struct ServerPool { database: String, client_server_map: ClientServerMap, stats: Arc, + auth_hash: Arc>>, } impl ServerPool { @@ -812,6 +860,7 @@ impl ServerPool { database: &str, client_server_map: ClientServerMap, stats: Arc, + auth_hash: Arc>>, ) -> ServerPool { ServerPool { address, @@ -819,6 +868,7 @@ impl ServerPool { database: database.to_string(), client_server_map, stats, + auth_hash, } } } @@ -847,6 +897,7 @@ impl ManageConnection for ServerPool { &self.database, self.client_server_map.clone(), stats.clone(), + self.auth_hash.clone(), ) .await { diff --git a/src/query_router.rs b/src/query_router.rs index 578c739..0ea907b 100644 --- a/src/query_router.rs +++ b/src/query_router.rs @@ -1110,6 +1110,9 @@ mod test { sharding_key_regex: None, shard_id_regex: None, regex_search_limit: 1000, + auth_query: None, + auth_query_password: None, + auth_query_user: None, }; let mut qr = QueryRouter::new(); assert_eq!(qr.active_role, None); @@ -1171,6 +1174,9 @@ mod test { sharding_key_regex: Some(Regex::new(r"/\* sharding_key: (\d+) \*/").unwrap()), shard_id_regex: Some(Regex::new(r"/\* shard_id: (\d+) \*/").unwrap()), regex_search_limit: 1000, + auth_query: None, + auth_query_password: None, + auth_query_user: None, }; let mut qr = QueryRouter::new(); qr.update_pool_settings(pool_settings.clone()); diff --git a/src/server.rs b/src/server.rs index d09313e..37f0e0c 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,7 +1,11 @@ /// Implementation of the PostgreSQL server (database) protocol. /// Here we are pretending to the a Postgres client. use bytes::{Buf, BufMut, BytesMut}; +use fallible_iterator::FallibleIterator; use log::{debug, error, info, trace, warn}; +use parking_lot::{Mutex, RwLock}; +use postgres_protocol::message; +use std::collections::HashMap; use std::io::Read; use std::sync::Arc; use std::time::SystemTime; @@ -81,6 +85,7 @@ impl Server { database: &str, client_server_map: ClientServerMap, stats: Arc, + auth_hash: Arc>>, ) -> Result { let mut stream = match TcpStream::connect(&format!("{}:{}", &address.host, address.port)).await { @@ -106,7 +111,10 @@ impl Server { // We'll be handling multiple packets, but they will all be structured the same. // We'll loop here until this exchange is complete. - let mut scram = ScramSha256::new(&user.password); + let mut scram: Option = None; + if let Some(password) = &user.password.clone() { + scram = Some(ScramSha256::new(password)); + } loop { let code = match stream.read_u8().await { @@ -143,13 +151,40 @@ impl Server { Err(_) => return Err(Error::SocketError(format!("Error reading salt on server startup {{ username: {:?}, database: {:?} }}", user.username, database))), }; - md5_password(&mut stream, &user.username, &user.password, &salt[..]) - .await?; + match &user.password { + // Using plaintext password + Some(password) => { + md5_password(&mut stream, &user.username, password, &salt[..]) + .await? + } + + // Using auth passthrough, in this case we should already have a + // hash obtained when the pool was validated. If we reach this point + // and don't have a hash, we return an error. + None => { + let option_hash = (*auth_hash.read()).clone(); + match option_hash { + Some(hash) => + md5_password_with_hash( + &mut stream, + &hash, + &salt[..], + ) + .await?, + None => + return Err(Error::AuthError(format!("Auth passthrough (auth_query) failed and no user password is set in cleartext for {{ username: {:?}, database: {:?} }}", user.username, database))) + } + } + } } AUTHENTICATION_SUCCESSFUL => (), SASL => { + if scram.is_none() { + return Err(Error::AuthError(format!("SASL auth required and not password specified, auth passthrough (auth_query) method is currently unsupported for SASL auth {{ username: {:?}, database: {:?} }}", user.username, database))); + } + debug!("Starting SASL authentication"); let sasl_len = (len - 8) as usize; let mut sasl_auth = vec![0u8; sasl_len]; @@ -165,7 +200,7 @@ impl Server { debug!("Using {}", SCRAM_SHA_256); // Generate client message. - let sasl_response = scram.message(); + let sasl_response = scram.as_mut().unwrap().message(); // SASLInitialResponse (F) let mut res = BytesMut::new(); @@ -202,7 +237,7 @@ impl Server { }; let msg = BytesMut::from(&sasl_data[..]); - let sasl_response = scram.update(&msg)?; + let sasl_response = scram.as_mut().unwrap().update(&msg)?; // SASLResponse let mut res = BytesMut::new(); @@ -222,7 +257,11 @@ impl Server { Err(_) => return Err(Error::SocketError(format!("Error reading sasl final message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))), }; - match scram.finish(&BytesMut::from(&sasl_final[..])) { + match scram + .as_mut() + .unwrap() + .finish(&BytesMut::from(&sasl_final[..])) + { Ok(_) => { debug!("SASL authentication successful"); } @@ -696,6 +735,105 @@ impl Server { None => (), } } + + // This is so we can execute out of band queries to the server. + // The connection will be opened, the query executed and closed. + pub async fn exec_simple_query( + address: &Address, + user: &User, + query: &str, + ) -> Result, Error> { + let client_server_map: ClientServerMap = Arc::new(Mutex::new(HashMap::new())); + + debug!("Connecting to server to obtain auth hashes."); + let mut server = Server::startup( + address, + user, + &address.database, + client_server_map, + Arc::new(ServerStats::default()), + Arc::new(RwLock::new(None)), + ) + .await?; + debug!("Connected!, sending query."); + server.send(&simple_query(query)).await?; + let mut message = server.recv().await?; + + Ok(parse_query_message(&mut message).await?) + } +} + +async fn parse_query_message(message: &mut BytesMut) -> Result, Error> { + let mut pair = Vec::::new(); + match message::backend::Message::parse(message) { + Ok(Some(message::backend::Message::RowDescription(_description))) => {} + Ok(Some(message::backend::Message::ErrorResponse(err))) => { + return Err(Error::ProtocolSyncError(format!( + "Protocol error parsing response. Err: {:?}", + err.fields() + .iterator() + .fold(String::default(), |acc, element| acc + + element.unwrap().value()) + ))) + } + Ok(_) => { + return Err(Error::ProtocolSyncError( + "Protocol error, expected Row Description.".to_string(), + )) + } + Err(err) => { + return Err(Error::ProtocolSyncError(format!( + "Protocol error parsing response. Err: {:?}", + err + ))) + } + } + + while !message.is_empty() { + match message::backend::Message::parse(message) { + Ok(postgres_message) => { + match postgres_message { + Some(message::backend::Message::DataRow(data)) => { + let buf = data.buffer(); + trace!("Data: {:?}", buf); + + for item in data.ranges().iterator() { + match item.as_ref() { + Ok(range) => match range { + Some(range) => { + pair.push(String::from_utf8_lossy(&buf[range.clone()]).to_string()); + } + None => return Err(Error::ProtocolSyncError(String::from( + "Data expected while receiving query auth data, found nothing.", + ))), + }, + Err(err) => { + return Err(Error::ProtocolSyncError(format!( + "Data error, err: {:?}", + err + ))) + } + } + } + } + Some(message::backend::Message::CommandComplete(_)) => {} + Some(message::backend::Message::ReadyForQuery(_)) => {} + _ => { + return Err(Error::ProtocolSyncError( + "Unexpected message while receiving auth query data.".to_string(), + )) + } + } + } + Err(err) => { + return Err(Error::ProtocolSyncError(format!( + "Parse error, err: {:?}", + err + ))) + } + }; + } + Ok(pair) } impl Drop for Server { diff --git a/tests/docker/docker-compose.yml b/tests/docker/docker-compose.yml index e57d852..93e9455 100644 --- a/tests/docker/docker-compose.yml +++ b/tests/docker/docker-compose.yml @@ -36,6 +36,15 @@ services: POSTGRES_PASSWORD: postgres POSTGRES_INITDB_ARGS: --auth-local=scram-sha-256 --auth-host=scram-sha-256 --auth=scram-sha-256 command: ["postgres", "-p", "9432", "-c", "shared_preload_libraries=pg_stat_statements", "-c", "pg_stat_statements.track=all", "-c", "pg_stat_statements.max=100000"] + pg5: + image: postgres:14 + network_mode: "service:main" + environment: + POSTGRES_USER: postgres + POSTGRES_DB: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_INITDB_ARGS: --auth-local=md5 --auth-host=md5 --auth=md5 + command: ["postgres", "-c", "shared_preload_libraries=pg_stat_statements", "-c", "pg_stat_statements.track=all", "-p", "10432"] main: build: . command: ["bash", "/app/tests/docker/run.sh"] diff --git a/tests/ruby/auth_query_spec.rb b/tests/ruby/auth_query_spec.rb new file mode 100644 index 0000000..1ac6216 --- /dev/null +++ b/tests/ruby/auth_query_spec.rb @@ -0,0 +1,215 @@ +# frozen_string_literal: true + +require_relative 'spec_helper' +require_relative 'helpers/auth_query_helper' + +describe "Auth Query" do + let(:configured_instances) {[5432, 10432]} + let(:config_user) { { 'username' => 'sharding_user', 'password' => 'sharding_user' } } + let(:pg_user) { { 'username' => 'sharding_user', 'password' => 'sharding_user' } } + let(:processes) { Helpers::AuthQuery.single_shard_auth_query(pool_name: "sharded_db", pg_user: pg_user, config_user: config_user, extra_conf: config, wait_until_ready: wait_until_ready ) } + let(:config) { {} } + let(:wait_until_ready) { true } + + after do + unless @failing_process + processes.all_databases.map(&:reset) + processes.pgcat.shutdown + end + @failing_process = false + end + + context "when auth_query is not configured" do + context 'and cleartext passwords are set' do + it "uses local passwords" do + conn = PG.connect(processes.pgcat.connection_string("sharded_db", config_user['username'], config_user['password'])) + + expect(conn.async_exec("SELECT 1 + 2")).not_to be_nil + end + end + + context 'and cleartext passwords are not set' do + let(:config_user) { { 'username' => 'sharding_user' } } + + it "does not start because it is not possible to authenticate" do + @failing_process = true + expect { processes.pgcat }.to raise_error(StandardError, /You have to specify a user password for every pool if auth_query is not specified/) + end + end + end + + context 'when auth_query is configured' do + context 'with global configuration' do + around(:example) do |example| + + # Set up auth query + Helpers::AuthQuery.set_up_auth_query_for_user( + user: 'md5_auth_user', + password: 'secret' + ); + + example.run + + # Drop auth query support + Helpers::AuthQuery.tear_down_auth_query_for_user( + user: 'md5_auth_user', + password: 'secret' + ); + end + + context 'with correct global parameters' do + let(:config) { { 'general' => { 'auth_query' => "SELECT * FROM public.user_lookup('$1');", 'auth_query_user' => 'md5_auth_user', 'auth_query_password' => 'secret' } } } + context 'and with cleartext passwords set' do + it 'it uses local passwords' do + conn = PG.connect(processes.pgcat.connection_string("sharded_db", pg_user['username'], pg_user['password'])) + expect(conn.exec("SELECT 1 + 2")).not_to be_nil + end + end + + context 'and with cleartext passwords not set' do + let(:config_user) { { 'username' => 'sharding_user', 'password' => 'sharding_user' } } + + it 'it uses obtained passwords' do + connection_string = processes.pgcat.connection_string("sharded_db", pg_user['username'], pg_user['password']) + conn = PG.connect(connection_string) + expect(conn.async_exec("SELECT 1 + 2")).not_to be_nil + end + + it 'allows passwords to be changed without closing existing connections' do + pgconn = PG.connect(processes.pgcat.connection_string("sharded_db", pg_user['username'])) + expect(pgconn.exec("SELECT 1 + 2")).not_to be_nil + Helpers::AuthQuery.exec_in_instances(query: "ALTER USER #{pg_user['username']} WITH ENCRYPTED PASSWORD 'secret2';") + expect(pgconn.exec("SELECT 1 + 4")).not_to be_nil + Helpers::AuthQuery.exec_in_instances(query: "ALTER USER #{pg_user['username']} WITH ENCRYPTED PASSWORD '#{pg_user['password']}';") + end + + it 'allows passwords to be changed and that new password is needed when reconnecting' do + pgconn = PG.connect(processes.pgcat.connection_string("sharded_db", pg_user['username'])) + expect(pgconn.exec("SELECT 1 + 2")).not_to be_nil + Helpers::AuthQuery.exec_in_instances(query: "ALTER USER #{pg_user['username']} WITH ENCRYPTED PASSWORD 'secret2';") + newconn = PG.connect(processes.pgcat.connection_string("sharded_db", pg_user['username'], 'secret2')) + expect(newconn.exec("SELECT 1 + 2")).not_to be_nil + Helpers::AuthQuery.exec_in_instances(query: "ALTER USER #{pg_user['username']} WITH ENCRYPTED PASSWORD '#{pg_user['password']}';") + end + end + end + + context 'with wrong parameters' do + let(:config) { { 'general' => { 'auth_query' => 'SELECT 1', 'auth_query_user' => 'wrong_user', 'auth_query_password' => 'wrong' } } } + + context 'and with clear text passwords set' do + it "it uses local passwords" do + conn = PG.connect(processes.pgcat.connection_string("sharded_db", pg_user['username'], pg_user['password'])) + + expect(conn.async_exec("SELECT 1 + 2")).not_to be_nil + end + end + + context 'and with cleartext passwords not set' do + let(:config_user) { { 'username' => 'sharding_user' } } + it "it fails to start as it cannot authenticate against servers" do + @failing_process = true + expect { PG.connect(processes.pgcat.connection_string("sharded_db", pg_user['username'], pg_user['password'])) }.to raise_error(StandardError, /Error trying to obtain password from auth_query/ ) + end + + context 'and we fix the issue and reload' do + let(:wait_until_ready) { false } + + it 'fails in the beginning but starts working after reloading config' do + connection_string = processes.pgcat.connection_string("sharded_db", pg_user['username'], pg_user['password']) + while !(processes.pgcat.logs =~ /Waiting for clients/) do + sleep 0.5 + end + + expect { PG.connect(connection_string)}.to raise_error(PG::ConnectionBad) + expect(processes.pgcat.logs).to match(/Error trying to obtain password from auth_query/) + + current_config = processes.pgcat.current_config + config = { 'general' => { 'auth_query' => "SELECT * FROM public.user_lookup('$1');", 'auth_query_user' => 'md5_auth_user', 'auth_query_password' => 'secret' } } + processes.pgcat.update_config(current_config.deep_merge(config)) + processes.pgcat.reload_config + + conn = nil + expect { conn = PG.connect(connection_string)}.not_to raise_error + expect(conn.async_exec("SELECT 1 + 2")).not_to be_nil + end + end + end + end + end + + context 'with per pool configuration' do + around(:example) do |example| + + # Set up auth query + Helpers::AuthQuery.set_up_auth_query_for_user( + user: 'md5_auth_user', + password: 'secret' + ); + + Helpers::AuthQuery.set_up_auth_query_for_user( + user: 'md5_auth_user1', + password: 'secret', + database: 'shard1' + ); + + example.run + + # Tear down auth query + Helpers::AuthQuery.tear_down_auth_query_for_user( + user: 'md5_auth_user', + password: 'secret' + ); + + Helpers::AuthQuery.tear_down_auth_query_for_user( + user: 'md5_auth_user1', + password: 'secret', + database: 'shard1' + ); + end + + context 'with correct parameters' do + let(:processes) { Helpers::AuthQuery.two_pools_auth_query(pool_names: ["sharded_db0", "sharded_db1"], pg_user: pg_user, config_user: config_user, extra_conf: config ) } + let(:config) { + { 'pools' => + { + 'sharded_db0' => { + 'auth_query' => "SELECT * FROM public.user_lookup('$1');", + 'auth_query_user' => 'md5_auth_user', + 'auth_query_password' => 'secret' + }, + 'sharded_db1' => { + 'auth_query' => "SELECT * FROM public.user_lookup('$1');", + 'auth_query_user' => 'md5_auth_user1', + 'auth_query_password' => 'secret' + }, + } + } + } + + context 'and with cleartext passwords set' do + it 'it uses local passwords' do + conn = PG.connect(processes.pgcat.connection_string("sharded_db0", pg_user['username'], pg_user['password'])) + expect(conn.exec("SELECT 1 + 2")).not_to be_nil + conn = PG.connect(processes.pgcat.connection_string("sharded_db1", pg_user['username'], pg_user['password'])) + expect(conn.exec("SELECT 1 + 2")).not_to be_nil + end + end + + context 'and with cleartext passwords not set' do + let(:config_user) { { 'username' => 'sharding_user' } } + + it 'it uses obtained passwords' do + connection_string = processes.pgcat.connection_string("sharded_db0", pg_user['username'], pg_user['password']) + conn = PG.connect(connection_string) + expect(conn.async_exec("SELECT 1 + 2")).not_to be_nil + connection_string = processes.pgcat.connection_string("sharded_db1", pg_user['username'], pg_user['password']) + conn = PG.connect(connection_string) + expect(conn.async_exec("SELECT 1 + 2")).not_to be_nil + end + end + + end + end + end +end diff --git a/tests/ruby/helpers/auth_query_helper.rb b/tests/ruby/helpers/auth_query_helper.rb new file mode 100644 index 0000000..60e8571 --- /dev/null +++ b/tests/ruby/helpers/auth_query_helper.rb @@ -0,0 +1,173 @@ +module Helpers + module AuthQuery + def self.single_shard_auth_query( + pg_user:, + config_user:, + pool_name:, + extra_conf: {}, + log_level: 'debug', + wait_until_ready: true + ) + + user = { + "pool_size" => 10, + "statement_timeout" => 0, + } + + pgcat = PgcatProcess.new(log_level) + pgcat_cfg = pgcat.current_config.deep_merge(extra_conf) + + primary = PgInstance.new(5432, pg_user["username"], pg_user["password"], "shard0") + replica = PgInstance.new(10432, pg_user["username"], pg_user["password"], "shard0") + + # Main proxy configs + pgcat_cfg["pools"] = { + "#{pool_name}" => { + "default_role" => "any", + "pool_mode" => "transaction", + "load_balancing_mode" => "random", + "primary_reads_enabled" => false, + "query_parser_enabled" => false, + "sharding_function" => "pg_bigint_hash", + "shards" => { + "0" => { + "database" => "shard0", + "servers" => [ + ["localhost", primary.port.to_s, "primary"], + ["localhost", replica.port.to_s, "replica"], + ] + }, + }, + "users" => { "0" => user.merge(config_user) } + } + } + pgcat_cfg["general"]["port"] = pgcat.port + pgcat.update_config(pgcat_cfg) + pgcat.start + + pgcat.wait_until_ready( + pgcat.connection_string( + "sharded_db", + pg_user['username'], + pg_user['password'] + ) + ) if wait_until_ready + + OpenStruct.new.tap do |struct| + struct.pgcat = pgcat + struct.primary = primary + struct.replicas = [replica] + struct.all_databases = [primary] + end + end + + def self.two_pools_auth_query( + pg_user:, + config_user:, + pool_names:, + extra_conf: {}, + log_level: 'debug' + ) + + user = { + "pool_size" => 10, + "statement_timeout" => 0, + } + + pgcat = PgcatProcess.new(log_level) + pgcat_cfg = pgcat.current_config + + primary = PgInstance.new(5432, pg_user["username"], pg_user["password"], "shard0") + replica = PgInstance.new(10432, pg_user["username"], pg_user["password"], "shard0") + + pool_template = Proc.new do |database| + { + "default_role" => "any", + "pool_mode" => "transaction", + "load_balancing_mode" => "random", + "primary_reads_enabled" => false, + "query_parser_enabled" => false, + "sharding_function" => "pg_bigint_hash", + "shards" => { + "0" => { + "database" => database, + "servers" => [ + ["localhost", primary.port.to_s, "primary"], + ["localhost", replica.port.to_s, "replica"], + ] + }, + }, + "users" => { "0" => user.merge(config_user) } + } + end + # Main proxy configs + pgcat_cfg["pools"] = { + "#{pool_names[0]}" => pool_template.call("shard0"), + "#{pool_names[1]}" => pool_template.call("shard1") + } + + pgcat_cfg["general"]["port"] = pgcat.port + pgcat.update_config(pgcat_cfg.deep_merge(extra_conf)) + pgcat.start + + pgcat.wait_until_ready(pgcat.connection_string("sharded_db0", pg_user['username'], pg_user['password'])) + + OpenStruct.new.tap do |struct| + struct.pgcat = pgcat + struct.primary = primary + struct.replicas = [replica] + struct.all_databases = [primary] + end + end + + def self.create_query_auth_function(user) + return <<-SQL +CREATE OR REPLACE FUNCTION public.user_lookup(in i_username text, out uname text, out phash text) +RETURNS record AS $$ +BEGIN + SELECT usename, passwd FROM pg_catalog.pg_shadow + WHERE usename = i_username INTO uname, phash; + RETURN; +END; +$$ LANGUAGE plpgsql SECURITY DEFINER; + +GRANT EXECUTE ON FUNCTION public.user_lookup(text) TO #{user}; +SQL + end + + def self.exec_in_instances(query:, instance_ports: [ 5432, 10432 ], database: 'postgres', user: 'postgres', password: 'postgres') + instance_ports.each do |port| + c = PG.connect("postgres://#{user}:#{password}@localhost:#{port}/#{database}") + c.exec(query) + c.close + end + end + + def self.set_up_auth_query_for_user(user:, password:, instance_ports: [ 5432, 10432 ], database: 'shard0' ) + instance_ports.each do |port| + connection = PG.connect("postgres://postgres:postgres@localhost:#{port}/#{database}") + connection.exec(self.drop_query_auth_function(user)) rescue PG::UndefinedFunction + connection.exec("DROP ROLE #{user}") rescue PG::UndefinedObject + connection.exec("CREATE ROLE #{user} ENCRYPTED PASSWORD '#{password}' LOGIN;") + connection.exec(self.create_query_auth_function(user)) + connection.close + end + end + + def self.tear_down_auth_query_for_user(user:, password:, instance_ports: [ 5432, 10432 ], database: 'shard0' ) + instance_ports.each do |port| + connection = PG.connect("postgres://postgres:postgres@localhost:#{port}/#{database}") + connection.exec(self.drop_query_auth_function(user)) rescue PG::UndefinedFunction + connection.exec("DROP ROLE #{user}") + connection.close + end + end + + def self.drop_query_auth_function(user) + return <<-SQL +REVOKE ALL ON FUNCTION public.user_lookup(text) FROM public, #{user}; +DROP FUNCTION public.user_lookup(in i_username text, out uname text, out phash text); +SQL + end + end +end diff --git a/tests/ruby/helpers/pgcat_helper.rb b/tests/ruby/helpers/pgcat_helper.rb index c4ebab7..13dc668 100644 --- a/tests/ruby/helpers/pgcat_helper.rb +++ b/tests/ruby/helpers/pgcat_helper.rb @@ -3,6 +3,13 @@ require 'ostruct' require_relative 'pgcat_process' require_relative 'pg_instance' +class ::Hash + def deep_merge(second) + merger = proc { |key, v1, v2| Hash === v1 && Hash === v2 ? v1.merge(v2, &merger) : v2 } + self.merge(second, &merger) + end +end + module Helpers module Pgcat def self.three_shard_setup(pool_name, pool_size, pool_mode="transaction", lb_mode="random", log_level="info") diff --git a/tests/ruby/helpers/pgcat_process.rb b/tests/ruby/helpers/pgcat_process.rb index 6120c99..e1dbea8 100644 --- a/tests/ruby/helpers/pgcat_process.rb +++ b/tests/ruby/helpers/pgcat_process.rb @@ -67,17 +67,21 @@ class PgcatProcess def start raise StandardError, "Process is already started" unless @pid.nil? @pid = Process.spawn(@env, @command, err: @log_filename, out: @log_filename) + Process.detach(@pid) ObjectSpace.define_finalizer(@log_filename, proc { PgcatProcess.finalize(@pid, @log_filename, @config_filename) }) return self end - def wait_until_ready + def wait_until_ready(connection_string = nil) exc = nil 10.times do - PG::connect(example_connection_string).close + Process.kill 0, @pid + PG::connect(connection_string || example_connection_string).close return self + rescue Errno::ESRCH + raise StandardError, "Process #{@pid} died. #{logs}" rescue => e exc = e sleep(0.5) @@ -108,13 +112,10 @@ class PgcatProcess "postgresql://#{username}:#{password}@0.0.0.0:#{@port}/pgcat" end - def connection_string(pool_name, username) + def connection_string(pool_name, username, password = nil) cfg = current_config - user_idx, user_obj = cfg["pools"][pool_name]["users"].detect { |k, user| user["username"] == username } - password = user_obj["password"] - - "postgresql://#{username}:#{password}@0.0.0.0:#{@port}/#{pool_name}" + "postgresql://#{username}:#{password || user_obj["password"]}@0.0.0.0:#{@port}/#{pool_name}" end def example_connection_string