mirror of
https://github.com/postgresml/pgcat.git
synced 2026-03-27 02:36:29 +00:00
Client md5 auth and clean up scram (#77)
* client md5 auth and clean up scram * add pw * add user * add user * log
This commit is contained in:
125
src/scram.rs
125
src/scram.rs
@@ -1,5 +1,6 @@
|
||||
// SCRAM authentication...largely copy/pasted from
|
||||
// https://github.com/sfackler/rust-postgres/.
|
||||
// SCRAM-SHA-256 authentication. Heavily inspired by
|
||||
// https://github.com/sfackler/rust-postgres/
|
||||
// SASL implementation.
|
||||
|
||||
use bytes::BytesMut;
|
||||
use hmac::{Hmac, Mac};
|
||||
@@ -12,6 +13,8 @@ use std::fmt::Write;
|
||||
use crate::constants::*;
|
||||
use crate::errors::Error;
|
||||
|
||||
/// Normalize a password string. Postgres
|
||||
/// passwords don't have to be UTF-8.
|
||||
fn normalize(pass: &[u8]) -> Vec<u8> {
|
||||
let pass = match std::str::from_utf8(pass) {
|
||||
Ok(pass) => pass,
|
||||
@@ -24,6 +27,8 @@ fn normalize(pass: &[u8]) -> Vec<u8> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Keep the SASL state through the exchange.
|
||||
/// It takes 3 messages to complete the authentication.
|
||||
pub struct ScramSha256 {
|
||||
password: String,
|
||||
salted_password: [u8; 32],
|
||||
@@ -33,6 +38,8 @@ pub struct ScramSha256 {
|
||||
}
|
||||
|
||||
impl ScramSha256 {
|
||||
/// Create the Scram state from a password. It'll automatically
|
||||
/// generate a nonce.
|
||||
pub fn new(password: &str) -> ScramSha256 {
|
||||
let mut rng = rand::thread_rng();
|
||||
let nonce = (0..NONCE_LENGTH)
|
||||
@@ -48,6 +55,7 @@ impl ScramSha256 {
|
||||
Self::from_nonce(password, &nonce)
|
||||
}
|
||||
|
||||
/// Used for testing.
|
||||
pub fn from_nonce(password: &str, nonce: &str) -> ScramSha256 {
|
||||
let message = BytesMut::from(&format!("{}n=,r={}", "n,,", nonce).as_bytes()[..]);
|
||||
|
||||
@@ -60,15 +68,16 @@ impl ScramSha256 {
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the current state of the SASL authentication.
|
||||
pub fn message(&mut self) -> BytesMut {
|
||||
self.message.clone()
|
||||
}
|
||||
|
||||
/// Update the state with message received from server.
|
||||
pub fn update(&mut self, message: &BytesMut) -> Result<BytesMut, Error> {
|
||||
let server_message = Message::parse(message)?;
|
||||
|
||||
if !server_message.nonce.starts_with(&self.nonce) {
|
||||
// trace!("Bad server nonce");
|
||||
return Err(Error::ProtocolSyncError);
|
||||
}
|
||||
|
||||
@@ -82,28 +91,39 @@ impl ScramSha256 {
|
||||
&salt,
|
||||
server_message.iterations,
|
||||
);
|
||||
|
||||
// Save for verification of final server message.
|
||||
self.salted_password = salted_password;
|
||||
|
||||
let mut hmac = Hmac::<Sha256>::new_from_slice(&salted_password)
|
||||
.expect("HMAC is able to accept all key sizes");
|
||||
let mut hmac = match Hmac::<Sha256>::new_from_slice(&salted_password) {
|
||||
Ok(hmac) => hmac,
|
||||
Err(_) => return Err(Error::ServerError),
|
||||
};
|
||||
|
||||
hmac.update(b"Client Key");
|
||||
|
||||
let client_key = hmac.finalize().into_bytes();
|
||||
|
||||
let mut hash = Sha256::default();
|
||||
hash.update(client_key.as_slice());
|
||||
let stored_key = hash.finalize_fixed();
|
||||
|
||||
let stored_key = hash.finalize_fixed();
|
||||
let mut cbind_input = vec![];
|
||||
cbind_input.extend("n,,".as_bytes());
|
||||
|
||||
let cbind_input = base64::encode(&cbind_input);
|
||||
|
||||
self.message.clear();
|
||||
write!(
|
||||
|
||||
// Start writing the client reply.
|
||||
match write!(
|
||||
&mut self.message,
|
||||
"c={},r={}",
|
||||
cbind_input, server_message.nonce
|
||||
)
|
||||
.unwrap();
|
||||
) {
|
||||
Ok(_) => (),
|
||||
Err(_) => return Err(Error::ServerError),
|
||||
};
|
||||
|
||||
let auth_message = format!(
|
||||
"n=,r={},{},{}",
|
||||
@@ -112,23 +132,32 @@ impl ScramSha256 {
|
||||
String::from_utf8_lossy(&self.message[..])
|
||||
);
|
||||
|
||||
let mut hmac = Hmac::<Sha256>::new_from_slice(&stored_key)
|
||||
.expect("HMAC is able to accept all key sizes");
|
||||
let mut hmac = match Hmac::<Sha256>::new_from_slice(&stored_key) {
|
||||
Ok(hmac) => hmac,
|
||||
Err(_) => return Err(Error::ServerError),
|
||||
};
|
||||
hmac.update(auth_message.as_bytes());
|
||||
|
||||
// Save the auth message for server final message verification.
|
||||
self.auth_message = auth_message;
|
||||
|
||||
let client_signature = hmac.finalize().into_bytes();
|
||||
|
||||
// Sign the client proof.
|
||||
let mut client_proof = client_key;
|
||||
for (proof, signature) in client_proof.iter_mut().zip(client_signature) {
|
||||
*proof ^= signature;
|
||||
}
|
||||
|
||||
write!(&mut self.message, ",p={}", base64::encode(&*client_proof)).unwrap();
|
||||
|
||||
self.auth_message = auth_message;
|
||||
match write!(&mut self.message, ",p={}", base64::encode(&*client_proof)) {
|
||||
Ok(_) => (),
|
||||
Err(_) => return Err(Error::ServerError),
|
||||
};
|
||||
|
||||
Ok(self.message.clone())
|
||||
}
|
||||
|
||||
/// Verify final server message.
|
||||
pub fn finish(&mut self, message: &BytesMut) -> Result<(), Error> {
|
||||
let final_message = FinalMessage::parse(message)?;
|
||||
|
||||
@@ -137,13 +166,17 @@ impl ScramSha256 {
|
||||
Err(_) => return Err(Error::ProtocolSyncError),
|
||||
};
|
||||
|
||||
let mut hmac = Hmac::<Sha256>::new_from_slice(&self.salted_password)
|
||||
.expect("HMAC is able to accept all key sizes");
|
||||
let mut hmac = match Hmac::<Sha256>::new_from_slice(&self.salted_password) {
|
||||
Ok(hmac) => hmac,
|
||||
Err(_) => return Err(Error::ServerError),
|
||||
};
|
||||
hmac.update(b"Server Key");
|
||||
let server_key = hmac.finalize().into_bytes();
|
||||
|
||||
let mut hmac = Hmac::<Sha256>::new_from_slice(&server_key)
|
||||
.expect("HMAC is able to accept all key sizes");
|
||||
let mut hmac = match Hmac::<Sha256>::new_from_slice(&server_key) {
|
||||
Ok(hmac) => hmac,
|
||||
Err(_) => return Err(Error::ServerError),
|
||||
};
|
||||
hmac.update(self.auth_message.as_bytes());
|
||||
|
||||
match hmac.verify_slice(&verifier) {
|
||||
@@ -152,7 +185,7 @@ impl ScramSha256 {
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/sfackler/rust-postgres/blob/c3a029e60c1c0bd0be947049859b8fa5bd5ac220/postgres-protocol/src/authentication/sasl.rs#L35
|
||||
/// Hash the password with the salt i-times.
|
||||
fn hi(str: &[u8], salt: &[u8], i: u32) -> [u8; 32] {
|
||||
let mut hmac =
|
||||
Hmac::<Sha256>::new_from_slice(str).expect("HMAC is able to accept all key sizes");
|
||||
@@ -176,7 +209,7 @@ impl ScramSha256 {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
/// Parse the server challenge.
|
||||
struct Message {
|
||||
nonce: String,
|
||||
salt: String,
|
||||
@@ -184,47 +217,21 @@ struct Message {
|
||||
}
|
||||
|
||||
impl Message {
|
||||
/// Parse the server SASL challenge.
|
||||
fn parse(message: &BytesMut) -> Result<Message, Error> {
|
||||
if !message.starts_with(b"r=") {
|
||||
let parts = String::from_utf8_lossy(&message[..])
|
||||
.split(",")
|
||||
.map(|s| s.to_string())
|
||||
.collect::<Vec<String>>();
|
||||
|
||||
if parts.len() != 3 {
|
||||
return Err(Error::ProtocolSyncError);
|
||||
}
|
||||
|
||||
let mut i = 2;
|
||||
|
||||
while message[i] != b',' && i < message.len() {
|
||||
i += 1;
|
||||
}
|
||||
|
||||
let nonce = String::from_utf8_lossy(&message[2..i]).to_string();
|
||||
|
||||
// Skip the ,
|
||||
i += 1;
|
||||
|
||||
if !&message[i..].starts_with(b"s=") {
|
||||
return Err(Error::ProtocolSyncError);
|
||||
}
|
||||
|
||||
// Skip the s=
|
||||
i += 2;
|
||||
|
||||
let s = i;
|
||||
while message[i] != b',' && i < message.len() {
|
||||
i += 1;
|
||||
}
|
||||
|
||||
let salt = String::from_utf8_lossy(&message[s..i]).to_string();
|
||||
|
||||
// Skip the ,
|
||||
i += 1;
|
||||
|
||||
if !&message[i..].starts_with(b"i=") {
|
||||
return Err(Error::ProtocolSyncError);
|
||||
}
|
||||
|
||||
i += 2;
|
||||
|
||||
let iterations = match String::from_utf8_lossy(&message[i..]).parse::<u32>() {
|
||||
Ok(it) => it,
|
||||
let nonce = str::replace(&parts[0], "r=", "");
|
||||
let salt = str::replace(&parts[1], "s=", "");
|
||||
let iterations = match str::replace(&parts[2], "i=", "").parse::<u32>() {
|
||||
Ok(iterations) => iterations,
|
||||
Err(_) => return Err(Error::ProtocolSyncError),
|
||||
};
|
||||
|
||||
@@ -236,13 +243,15 @@ impl Message {
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse server final validation message.
|
||||
struct FinalMessage {
|
||||
value: String,
|
||||
}
|
||||
|
||||
impl FinalMessage {
|
||||
/// Parse the server final validation message.
|
||||
pub fn parse(message: &BytesMut) -> Result<FinalMessage, Error> {
|
||||
if !message.starts_with(b"v=") {
|
||||
if !message.starts_with(b"v=") || message.len() < 4 {
|
||||
return Err(Error::ProtocolSyncError);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user