diff --git a/.circleci/run_tests.sh b/.circleci/run_tests.sh index f325d16..66fcb79 100644 --- a/.circleci/run_tests.sh +++ b/.circleci/run_tests.sh @@ -13,6 +13,7 @@ function start_pgcat() { # Setup the database with shards and user PGPASSWORD=postgres psql -e -h 127.0.0.1 -p 5432 -U postgres -f tests/sharding/query_routing_setup.sql + PGPASSWORD=sharding_user pgbench -h 127.0.0.1 -U sharding_user shard0 -i PGPASSWORD=sharding_user pgbench -h 127.0.0.1 -U sharding_user shard1 -i PGPASSWORD=sharding_user pgbench -h 127.0.0.1 -U sharding_user shard2 -i @@ -30,26 +31,28 @@ toxiproxy-cli create -l 127.0.0.1:5433 -u 127.0.0.1:5432 postgres_replica start_pgcat "info" +export PGPASSWORD=sharding_user + # pgbench test -pgbench -i -h 127.0.0.1 -p 6432 -pgbench -h 127.0.0.1 -p 6432 -t 500 -c 2 --protocol simple -f tests/pgbench/simple.sql -pgbench -h 127.0.0.1 -p 6432 -t 500 -c 2 --protocol extended +pgbench -U sharding_user -i -h 127.0.0.1 -p 6432 +pgbench -U sharding_user -h 127.0.0.1 -p 6432 -t 500 -c 2 --protocol simple -f tests/pgbench/simple.sql +pgbench -U sharding_user -h 127.0.0.1 -p 6432 -t 500 -c 2 --protocol extended # COPY TO STDOUT test -psql -h 127.0.0.1 -p 6432 -c 'COPY (SELECT * FROM pgbench_accounts LIMIT 15) TO STDOUT;' > /dev/null +psql -U sharding_user -h 127.0.0.1 -p 6432 -c 'COPY (SELECT * FROM pgbench_accounts LIMIT 15) TO STDOUT;' > /dev/null # Query cancellation test -(psql -h 127.0.0.1 -p 6432 -c 'SELECT pg_sleep(5)' || true) & +(psql -U sharding_user -h 127.0.0.1 -p 6432 -c 'SELECT pg_sleep(5)' || true) & killall psql -s SIGINT # Sharding insert -psql -e -h 127.0.0.1 -p 6432 -f tests/sharding/query_routing_test_insert.sql +psql -U sharding_user -e -h 127.0.0.1 -p 6432 -f tests/sharding/query_routing_test_insert.sql # Sharding select -psql -e -h 127.0.0.1 -p 6432 -f tests/sharding/query_routing_test_select.sql > /dev/null +psql -U sharding_user -e -h 127.0.0.1 -p 6432 -f tests/sharding/query_routing_test_select.sql > /dev/null # Replica/primary selection & more sharding tests -psql -e -h 127.0.0.1 -p 6432 -f tests/sharding/query_routing_test_primary_replica.sql > /dev/null +psql -U sharding_user -e -h 127.0.0.1 -p 6432 -f tests/sharding/query_routing_test_primary_replica.sql > /dev/null # # ActiveRecord tests @@ -61,15 +64,15 @@ cd tests/ruby && \ cd ../.. # Admin tests -psql -e -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW STATS' > /dev/null -psql -h 127.0.0.1 -p 6432 -d pgbouncer -c 'RELOAD' > /dev/null -psql -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW CONFIG' > /dev/null -psql -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW DATABASES' > /dev/null -psql -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW LISTS' > /dev/null -psql -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW POOLS' > /dev/null -psql -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW VERSION' > /dev/null -psql -h 127.0.0.1 -p 6432 -d pgbouncer -c "SET client_encoding TO 'utf8'" > /dev/null # will ignore -(! psql -e -h 127.0.0.1 -p 6432 -d random_db -c 'SHOW STATS' > /dev/null) +psql -U sharding_user -e -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW STATS' > /dev/null +psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'RELOAD' > /dev/null +psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW CONFIG' > /dev/null +psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW DATABASES' > /dev/null +psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW LISTS' > /dev/null +psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW POOLS' > /dev/null +psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW VERSION' > /dev/null +psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c "SET client_encoding TO 'utf8'" > /dev/null # will ignore +(! psql -U sharding_user -e -h 127.0.0.1 -p 6432 -d random_db -c 'SHOW STATS' > /dev/null) # Start PgCat in debug to demonstrate failover better start_pgcat "trace" @@ -79,7 +82,7 @@ toxiproxy-cli toxic add -t latency -a latency=300 postgres_replica sleep 1 # Note the failover in the logs -timeout 5 psql -e -h 127.0.0.1 -p 6432 <<-EOF +timeout 5 psql -U sharding_user -e -h 127.0.0.1 -p 6432 <<-EOF SELECT 1; SELECT 1; SELECT 1; @@ -97,7 +100,7 @@ sed -i 's/pool_mode = "transaction"/pool_mode = "session"/' pgcat.toml kill -SIGHUP $(pgrep pgcat) # Prepared statements that will only work in session mode -pgbench -h 127.0.0.1 -p 6432 -t 500 -c 2 --protocol prepared +pgbench -U sharding_user -h 127.0.0.1 -p 6432 -t 500 -c 2 --protocol prepared # Attempt clean shut down killall pgcat -s SIGINT diff --git a/Cargo.lock b/Cargo.lock index 51bbc6e..668f421 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -368,7 +368,7 @@ dependencies = [ [[package]] name = "pgcat" -version = "0.1.0-beta2" +version = "0.2.0-beta1" dependencies = [ "arc-swap", "async-trait", diff --git a/Cargo.toml b/Cargo.toml index ae18d7b..924b9cb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pgcat" -version = "0.2.0-beta1" +version = "0.2.1-beta1" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/src/client.rs b/src/client.rs index 2c96ccc..08b7049 100644 --- a/src/client.rs +++ b/src/client.rs @@ -72,9 +72,9 @@ impl Client { server_info: BytesMut, stats: Reporter, ) -> Result { - let config = get_config(); + let config = get_config().clone(); let transaction_mode = config.general.pool_mode.starts_with("t"); - drop(config); + // drop(config); loop { trace!("Waiting for StartupMessage"); @@ -108,14 +108,51 @@ impl Client { // Regular startup message. PROTOCOL_VERSION_NUMBER => { trace!("Got StartupMessage"); - - // TODO: perform actual auth. let parameters = parse_startup(bytes.clone())?; // Generate random backend ID and secret key let process_id: i32 = rand::random(); let secret_key: i32 = rand::random(); + // Perform MD5 authentication. + // TODO: Add SASL support. + let salt = md5_challenge(&mut stream).await?; + + let code = match stream.read_u8().await { + Ok(p) => p, + Err(_) => return Err(Error::SocketError), + }; + + // PasswordMessage + if code as char != 'p' { + debug!("Expected p, got {}", code as char); + return Err(Error::ProtocolSyncError); + } + + let len = match stream.read_i32().await { + Ok(len) => len, + Err(_) => return Err(Error::SocketError), + }; + + let mut password_response = vec![0u8; (len - 4) as usize]; + + match stream.read_exact(&mut password_response).await { + Ok(_) => (), + Err(_) => return Err(Error::SocketError), + }; + + // Compare server and client hashes. + let password_hash = + md5_hash_password(&config.user.name, &config.user.password, &salt); + + if password_hash != password_response { + debug!("Password authentication failed"); + wrong_password(&mut stream, &config.user.name).await?; + return Err(Error::ClientError); + } + + debug!("Password authentication successful"); + auth_ok(&mut stream).await?; write_all(&mut stream, server_info).await?; backend_key_data(&mut stream, process_id, secret_key).await?; diff --git a/src/errors.rs b/src/errors.rs index b42d321..b07d508 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -9,4 +9,5 @@ pub enum Error { ServerError, BadConfig, AllServersDown, + ClientError, } diff --git a/src/messages.rs b/src/messages.rs index 3420bb4..993545b 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -40,6 +40,26 @@ pub async fn auth_ok(stream: &mut TcpStream) -> Result<(), Error> { Ok(write_all(stream, auth_ok).await?) } +/// Generate md5 password challenge. +pub async fn md5_challenge(stream: &mut TcpStream) -> Result<[u8; 4], Error> { + // let mut rng = rand::thread_rng(); + let salt: [u8; 4] = [ + rand::random(), + rand::random(), + rand::random(), + rand::random(), + ]; + + let mut res = BytesMut::new(); + res.put_u8(b'R'); + res.put_i32(12); + res.put_i32(5); // MD5 + res.put_slice(&salt[..]); + + write_all(stream, res).await?; + Ok(salt) +} + /// Give the client the process_id and secret we generated /// used in query cancellation. pub async fn backend_key_data( @@ -160,14 +180,8 @@ pub fn parse_startup(bytes: BytesMut) -> Result, Error> Ok(result) } -/// Send password challenge response to the server. -/// This is the MD5 challenge. -pub async fn md5_password( - stream: &mut TcpStream, - user: &str, - password: &str, - salt: &[u8], -) -> Result<(), Error> { +/// Create md5 password hash given a salt. +pub fn md5_hash_password(user: &str, password: &str, salt: &[u8]) -> Vec { let mut md5 = Md5::new(); // First pass @@ -186,6 +200,19 @@ pub async fn md5_password( .collect::>(); password.push(0); + password +} + +/// Send password challenge response to the server. +/// This is the MD5 challenge. +pub async fn md5_password( + stream: &mut TcpStream, + user: &str, + password: &str, + salt: &[u8], +) -> Result<(), Error> { + let password = md5_hash_password(user, password, salt); + let mut message = BytesMut::with_capacity(password.len() as usize + 5); message.put_u8(b'p'); @@ -264,6 +291,39 @@ pub async fn error_response(stream: &mut OwnedWriteHalf, message: &str) -> Resul Ok(write_all_half(stream, res).await?) } +pub async fn wrong_password(stream: &mut TcpStream, user: &str) -> Result<(), Error> { + let mut error = BytesMut::new(); + + // Error level + error.put_u8(b'S'); + error.put_slice(&b"FATAL\0"[..]); + + // Error level (non-translatable) + error.put_u8(b'V'); + error.put_slice(&b"FATAL\0"[..]); + + // Error code: not sure how much this matters. + error.put_u8(b'C'); + error.put_slice(&b"28P01\0"[..]); // system_error, see Appendix A. + + // The short error message. + error.put_u8(b'M'); + error.put_slice(&format!("password authentication failed for user \"{}\"\0", user).as_bytes()); + + // No more fields follow. + error.put_u8(0); + + // Compose the two message reply. + let mut res = BytesMut::new(); + + res.put_u8(b'E'); + res.put_i32(error.len() as i32 + 4); + + res.put(error); + + write_all(stream, res).await +} + /// Respond to a SHOW SHARD command. pub async fn show_response( stream: &mut OwnedWriteHalf, diff --git a/src/scram.rs b/src/scram.rs index 58096fa..514ed7a 100644 --- a/src/scram.rs +++ b/src/scram.rs @@ -1,5 +1,6 @@ -// SCRAM authentication...largely copy/pasted from -// https://github.com/sfackler/rust-postgres/. +// SCRAM-SHA-256 authentication. Heavily inspired by +// https://github.com/sfackler/rust-postgres/ +// SASL implementation. use bytes::BytesMut; use hmac::{Hmac, Mac}; @@ -12,6 +13,8 @@ use std::fmt::Write; use crate::constants::*; use crate::errors::Error; +/// Normalize a password string. Postgres +/// passwords don't have to be UTF-8. fn normalize(pass: &[u8]) -> Vec { let pass = match std::str::from_utf8(pass) { Ok(pass) => pass, @@ -24,6 +27,8 @@ fn normalize(pass: &[u8]) -> Vec { } } +/// Keep the SASL state through the exchange. +/// It takes 3 messages to complete the authentication. pub struct ScramSha256 { password: String, salted_password: [u8; 32], @@ -33,6 +38,8 @@ pub struct ScramSha256 { } impl ScramSha256 { + /// Create the Scram state from a password. It'll automatically + /// generate a nonce. pub fn new(password: &str) -> ScramSha256 { let mut rng = rand::thread_rng(); let nonce = (0..NONCE_LENGTH) @@ -48,6 +55,7 @@ impl ScramSha256 { Self::from_nonce(password, &nonce) } + /// Used for testing. pub fn from_nonce(password: &str, nonce: &str) -> ScramSha256 { let message = BytesMut::from(&format!("{}n=,r={}", "n,,", nonce).as_bytes()[..]); @@ -60,15 +68,16 @@ impl ScramSha256 { } } + /// Get the current state of the SASL authentication. pub fn message(&mut self) -> BytesMut { self.message.clone() } + /// Update the state with message received from server. pub fn update(&mut self, message: &BytesMut) -> Result { let server_message = Message::parse(message)?; if !server_message.nonce.starts_with(&self.nonce) { - // trace!("Bad server nonce"); return Err(Error::ProtocolSyncError); } @@ -82,28 +91,39 @@ impl ScramSha256 { &salt, server_message.iterations, ); + + // Save for verification of final server message. self.salted_password = salted_password; - let mut hmac = Hmac::::new_from_slice(&salted_password) - .expect("HMAC is able to accept all key sizes"); + let mut hmac = match Hmac::::new_from_slice(&salted_password) { + Ok(hmac) => hmac, + Err(_) => return Err(Error::ServerError), + }; + hmac.update(b"Client Key"); + let client_key = hmac.finalize().into_bytes(); let mut hash = Sha256::default(); hash.update(client_key.as_slice()); - let stored_key = hash.finalize_fixed(); + let stored_key = hash.finalize_fixed(); let mut cbind_input = vec![]; cbind_input.extend("n,,".as_bytes()); + let cbind_input = base64::encode(&cbind_input); self.message.clear(); - write!( + + // Start writing the client reply. + match write!( &mut self.message, "c={},r={}", cbind_input, server_message.nonce - ) - .unwrap(); + ) { + Ok(_) => (), + Err(_) => return Err(Error::ServerError), + }; let auth_message = format!( "n=,r={},{},{}", @@ -112,23 +132,32 @@ impl ScramSha256 { String::from_utf8_lossy(&self.message[..]) ); - let mut hmac = Hmac::::new_from_slice(&stored_key) - .expect("HMAC is able to accept all key sizes"); + let mut hmac = match Hmac::::new_from_slice(&stored_key) { + Ok(hmac) => hmac, + Err(_) => return Err(Error::ServerError), + }; hmac.update(auth_message.as_bytes()); + + // Save the auth message for server final message verification. + self.auth_message = auth_message; + let client_signature = hmac.finalize().into_bytes(); + // Sign the client proof. let mut client_proof = client_key; for (proof, signature) in client_proof.iter_mut().zip(client_signature) { *proof ^= signature; } - write!(&mut self.message, ",p={}", base64::encode(&*client_proof)).unwrap(); - - self.auth_message = auth_message; + match write!(&mut self.message, ",p={}", base64::encode(&*client_proof)) { + Ok(_) => (), + Err(_) => return Err(Error::ServerError), + }; Ok(self.message.clone()) } + /// Verify final server message. pub fn finish(&mut self, message: &BytesMut) -> Result<(), Error> { let final_message = FinalMessage::parse(message)?; @@ -137,13 +166,17 @@ impl ScramSha256 { Err(_) => return Err(Error::ProtocolSyncError), }; - let mut hmac = Hmac::::new_from_slice(&self.salted_password) - .expect("HMAC is able to accept all key sizes"); + let mut hmac = match Hmac::::new_from_slice(&self.salted_password) { + Ok(hmac) => hmac, + Err(_) => return Err(Error::ServerError), + }; hmac.update(b"Server Key"); let server_key = hmac.finalize().into_bytes(); - let mut hmac = Hmac::::new_from_slice(&server_key) - .expect("HMAC is able to accept all key sizes"); + let mut hmac = match Hmac::::new_from_slice(&server_key) { + Ok(hmac) => hmac, + Err(_) => return Err(Error::ServerError), + }; hmac.update(self.auth_message.as_bytes()); match hmac.verify_slice(&verifier) { @@ -152,7 +185,7 @@ impl ScramSha256 { } } - // https://github.com/sfackler/rust-postgres/blob/c3a029e60c1c0bd0be947049859b8fa5bd5ac220/postgres-protocol/src/authentication/sasl.rs#L35 + /// Hash the password with the salt i-times. fn hi(str: &[u8], salt: &[u8], i: u32) -> [u8; 32] { let mut hmac = Hmac::::new_from_slice(str).expect("HMAC is able to accept all key sizes"); @@ -176,7 +209,7 @@ impl ScramSha256 { } } -#[derive(Default, Debug)] +/// Parse the server challenge. struct Message { nonce: String, salt: String, @@ -184,47 +217,21 @@ struct Message { } impl Message { + /// Parse the server SASL challenge. fn parse(message: &BytesMut) -> Result { - if !message.starts_with(b"r=") { + let parts = String::from_utf8_lossy(&message[..]) + .split(",") + .map(|s| s.to_string()) + .collect::>(); + + if parts.len() != 3 { return Err(Error::ProtocolSyncError); } - let mut i = 2; - - while message[i] != b',' && i < message.len() { - i += 1; - } - - let nonce = String::from_utf8_lossy(&message[2..i]).to_string(); - - // Skip the , - i += 1; - - if !&message[i..].starts_with(b"s=") { - return Err(Error::ProtocolSyncError); - } - - // Skip the s= - i += 2; - - let s = i; - while message[i] != b',' && i < message.len() { - i += 1; - } - - let salt = String::from_utf8_lossy(&message[s..i]).to_string(); - - // Skip the , - i += 1; - - if !&message[i..].starts_with(b"i=") { - return Err(Error::ProtocolSyncError); - } - - i += 2; - - let iterations = match String::from_utf8_lossy(&message[i..]).parse::() { - Ok(it) => it, + let nonce = str::replace(&parts[0], "r=", ""); + let salt = str::replace(&parts[1], "s=", ""); + let iterations = match str::replace(&parts[2], "i=", "").parse::() { + Ok(iterations) => iterations, Err(_) => return Err(Error::ProtocolSyncError), }; @@ -236,13 +243,15 @@ impl Message { } } +/// Parse server final validation message. struct FinalMessage { value: String, } impl FinalMessage { + /// Parse the server final validation message. pub fn parse(message: &BytesMut) -> Result { - if !message.starts_with(b"v=") { + if !message.starts_with(b"v=") || message.len() < 4 { return Err(Error::ProtocolSyncError); } diff --git a/src/server.rs b/src/server.rs index 3670af9..b20d153 100644 --- a/src/server.rs +++ b/src/server.rs @@ -137,6 +137,7 @@ impl Server { debug!("Starting SASL authentication"); let sasl_len = (len - 8) as usize; let mut sasl_auth = vec![0u8; sasl_len]; + match stream.read_exact(&mut sasl_auth).await { Ok(_) => (), Err(_) => return Err(Error::SocketError), @@ -147,16 +148,22 @@ impl Server { if sasl_type == SCRAM_SHA_256 { debug!("Using {}", SCRAM_SHA_256); - // Send client message + // Generate client message. let sasl_response = scram.message(); + + // SASLInitialResponse (F) let mut res = BytesMut::new(); res.put_u8(b'p'); + + // length + String length + length + length of sasl response res.put_i32( - 4 + SCRAM_SHA_256.len() as i32 - + 1 - + sasl_response.len() as i32 - + 4, + 4 // i32 size + + SCRAM_SHA_256.len() as i32 // length of SASL version string, + + 1 // Null terminator for the SASL version string, + + 4 // i32 size + + sasl_response.len() as i32, // length of SASL response ); + res.put_slice(&format!("{}\0", SCRAM_SHA_256).as_bytes()[..]); res.put_i32(sasl_response.len() as i32); res.put(sasl_response); @@ -181,6 +188,7 @@ impl Server { let msg = BytesMut::from(&sasl_data[..]); let sasl_response = scram.update(&msg)?; + // SASLResponse let mut res = BytesMut::new(); res.put_u8(b'p'); res.put_i32(4 + sasl_response.len() as i32);