mirror of
https://github.com/postgresml/pgcat.git
synced 2026-03-23 09:26:30 +00:00
Compare commits
2 Commits
v0.2.0-bet
...
v0.2.1-bet
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d865d9f9d8 | ||
|
|
d3310a62c2 |
@@ -13,6 +13,7 @@ function start_pgcat() {
|
||||
|
||||
# Setup the database with shards and user
|
||||
PGPASSWORD=postgres psql -e -h 127.0.0.1 -p 5432 -U postgres -f tests/sharding/query_routing_setup.sql
|
||||
|
||||
PGPASSWORD=sharding_user pgbench -h 127.0.0.1 -U sharding_user shard0 -i
|
||||
PGPASSWORD=sharding_user pgbench -h 127.0.0.1 -U sharding_user shard1 -i
|
||||
PGPASSWORD=sharding_user pgbench -h 127.0.0.1 -U sharding_user shard2 -i
|
||||
@@ -30,26 +31,28 @@ toxiproxy-cli create -l 127.0.0.1:5433 -u 127.0.0.1:5432 postgres_replica
|
||||
|
||||
start_pgcat "info"
|
||||
|
||||
export PGPASSWORD=sharding_user
|
||||
|
||||
# pgbench test
|
||||
pgbench -i -h 127.0.0.1 -p 6432
|
||||
pgbench -h 127.0.0.1 -p 6432 -t 500 -c 2 --protocol simple -f tests/pgbench/simple.sql
|
||||
pgbench -h 127.0.0.1 -p 6432 -t 500 -c 2 --protocol extended
|
||||
pgbench -U sharding_user -i -h 127.0.0.1 -p 6432
|
||||
pgbench -U sharding_user -h 127.0.0.1 -p 6432 -t 500 -c 2 --protocol simple -f tests/pgbench/simple.sql
|
||||
pgbench -U sharding_user -h 127.0.0.1 -p 6432 -t 500 -c 2 --protocol extended
|
||||
|
||||
# COPY TO STDOUT test
|
||||
psql -h 127.0.0.1 -p 6432 -c 'COPY (SELECT * FROM pgbench_accounts LIMIT 15) TO STDOUT;' > /dev/null
|
||||
psql -U sharding_user -h 127.0.0.1 -p 6432 -c 'COPY (SELECT * FROM pgbench_accounts LIMIT 15) TO STDOUT;' > /dev/null
|
||||
|
||||
# Query cancellation test
|
||||
(psql -h 127.0.0.1 -p 6432 -c 'SELECT pg_sleep(5)' || true) &
|
||||
(psql -U sharding_user -h 127.0.0.1 -p 6432 -c 'SELECT pg_sleep(5)' || true) &
|
||||
killall psql -s SIGINT
|
||||
|
||||
# Sharding insert
|
||||
psql -e -h 127.0.0.1 -p 6432 -f tests/sharding/query_routing_test_insert.sql
|
||||
psql -U sharding_user -e -h 127.0.0.1 -p 6432 -f tests/sharding/query_routing_test_insert.sql
|
||||
|
||||
# Sharding select
|
||||
psql -e -h 127.0.0.1 -p 6432 -f tests/sharding/query_routing_test_select.sql > /dev/null
|
||||
psql -U sharding_user -e -h 127.0.0.1 -p 6432 -f tests/sharding/query_routing_test_select.sql > /dev/null
|
||||
|
||||
# Replica/primary selection & more sharding tests
|
||||
psql -e -h 127.0.0.1 -p 6432 -f tests/sharding/query_routing_test_primary_replica.sql > /dev/null
|
||||
psql -U sharding_user -e -h 127.0.0.1 -p 6432 -f tests/sharding/query_routing_test_primary_replica.sql > /dev/null
|
||||
|
||||
#
|
||||
# ActiveRecord tests
|
||||
@@ -61,15 +64,15 @@ cd tests/ruby && \
|
||||
cd ../..
|
||||
|
||||
# Admin tests
|
||||
psql -e -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW STATS' > /dev/null
|
||||
psql -h 127.0.0.1 -p 6432 -d pgbouncer -c 'RELOAD' > /dev/null
|
||||
psql -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW CONFIG' > /dev/null
|
||||
psql -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW DATABASES' > /dev/null
|
||||
psql -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW LISTS' > /dev/null
|
||||
psql -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW POOLS' > /dev/null
|
||||
psql -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW VERSION' > /dev/null
|
||||
psql -h 127.0.0.1 -p 6432 -d pgbouncer -c "SET client_encoding TO 'utf8'" > /dev/null # will ignore
|
||||
(! psql -e -h 127.0.0.1 -p 6432 -d random_db -c 'SHOW STATS' > /dev/null)
|
||||
psql -U sharding_user -e -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW STATS' > /dev/null
|
||||
psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'RELOAD' > /dev/null
|
||||
psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW CONFIG' > /dev/null
|
||||
psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW DATABASES' > /dev/null
|
||||
psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW LISTS' > /dev/null
|
||||
psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW POOLS' > /dev/null
|
||||
psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW VERSION' > /dev/null
|
||||
psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c "SET client_encoding TO 'utf8'" > /dev/null # will ignore
|
||||
(! psql -U sharding_user -e -h 127.0.0.1 -p 6432 -d random_db -c 'SHOW STATS' > /dev/null)
|
||||
|
||||
# Start PgCat in debug to demonstrate failover better
|
||||
start_pgcat "trace"
|
||||
@@ -79,7 +82,7 @@ toxiproxy-cli toxic add -t latency -a latency=300 postgres_replica
|
||||
sleep 1
|
||||
|
||||
# Note the failover in the logs
|
||||
timeout 5 psql -e -h 127.0.0.1 -p 6432 <<-EOF
|
||||
timeout 5 psql -U sharding_user -e -h 127.0.0.1 -p 6432 <<-EOF
|
||||
SELECT 1;
|
||||
SELECT 1;
|
||||
SELECT 1;
|
||||
@@ -97,7 +100,7 @@ sed -i 's/pool_mode = "transaction"/pool_mode = "session"/' pgcat.toml
|
||||
kill -SIGHUP $(pgrep pgcat)
|
||||
|
||||
# Prepared statements that will only work in session mode
|
||||
pgbench -h 127.0.0.1 -p 6432 -t 500 -c 2 --protocol prepared
|
||||
pgbench -U sharding_user -h 127.0.0.1 -p 6432 -t 500 -c 2 --protocol prepared
|
||||
|
||||
# Attempt clean shut down
|
||||
killall pgcat -s SIGINT
|
||||
|
||||
2
Cargo.lock
generated
2
Cargo.lock
generated
@@ -368,7 +368,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "pgcat"
|
||||
version = "0.1.0-beta2"
|
||||
version = "0.2.0-beta1"
|
||||
dependencies = [
|
||||
"arc-swap",
|
||||
"async-trait",
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "pgcat"
|
||||
version = "0.2.0-beta1"
|
||||
version = "0.2.1-beta1"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
26
README.md
26
README.md
@@ -9,19 +9,19 @@ PostgreSQL pooler (like PgBouncer) with sharding, load balancing and failover su
|
||||
**Beta**: looking for beta testers, see [#35](https://github.com/levkk/pgcat/issues/35).
|
||||
|
||||
## Features
|
||||
| **Feature** | **Status** | **Comments** |
|
||||
|--------------------------------|-----------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| Transaction pooling | :white_check_mark: | Identical to PgBouncer. |
|
||||
| Session pooling | :white_check_mark: | Identical to PgBouncer. |
|
||||
| `COPY` support | :white_check_mark: | Both `COPY TO` and `COPY FROM` are supported. |
|
||||
| Query cancellation | :white_check_mark: | Supported both in transaction and session pooling modes. |
|
||||
| Load balancing of read queries | :white_check_mark: | Using round-robin between replicas. Primary is included when `primary_reads_enabled` is enabled (default). |
|
||||
| Sharding | :white_check_mark: | Transactions are sharded using `SET SHARD TO` and `SET SHARDING KEY TO` syntax extensions; see examples below. |
|
||||
| Failover | :white_check_mark: | Replicas are tested with a health check. If a health check fails, remaining replicas are attempted; see below for algorithm description and examples. |
|
||||
| Statistics | :white_check_mark: | Statistics available in the admin database (`pgcat` and `pgbouncer`) with `SHOW STATS`, `SHOW POOLS` and others. |
|
||||
| Live configuration reloading | :white_check_mark: | Reload supported settings with a `SIGHUP` to the process, e.g. `kill -s SIGHUP $(pgrep pgcat)` or `RELOAD` query issued to the admin database. |
|
||||
| Client authentication | :x: :wrench: | On the roadmap; currently all clients are allowed to connect and one user is used to connect to Postgres. |
|
||||
| Admin database | :white_check_mark: | The admin database, similar to PgBouncer's, allows to query for statistics and reload the configuration. |
|
||||
| **Feature** | **Status** | **Comments** |
|
||||
|--------------------------------|-----------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| Transaction pooling | :white_check_mark: | Identical to PgBouncer. |
|
||||
| Session pooling | :white_check_mark: | Identical to PgBouncer. |
|
||||
| `COPY` support | :white_check_mark: | Both `COPY TO` and `COPY FROM` are supported. |
|
||||
| Query cancellation | :white_check_mark: | Supported both in transaction and session pooling modes. |
|
||||
| Load balancing of read queries | :white_check_mark: | Using round-robin between replicas. Primary is included when `primary_reads_enabled` is enabled (default). |
|
||||
| Sharding | :white_check_mark: | Transactions are sharded using `SET SHARD TO` and `SET SHARDING KEY TO` syntax extensions; see examples below. |
|
||||
| Failover | :white_check_mark: | Replicas are tested with a health check. If a health check fails, remaining replicas are attempted; see below for algorithm description and examples. |
|
||||
| Statistics | :white_check_mark: | Statistics available in the admin database (`pgcat` and `pgbouncer`) with `SHOW STATS`, `SHOW POOLS` and others. |
|
||||
| Live configuration reloading | :white_check_mark: | Reload supported settings with a `SIGHUP` to the process, e.g. `kill -s SIGHUP $(pgrep pgcat)` or `RELOAD` query issued to the admin database. |
|
||||
| Client authentication | :white_check_mark: :wrench: | MD5 password authentication is supported, SCRAM is on the roadmap; one user is used to connect to Postgres with both SCRAM and MD5 supported. |
|
||||
| Admin database | :white_check_mark: | The admin database, similar to PgBouncer's, allows to query for statistics and reload the configuration. |
|
||||
|
||||
## Deployment
|
||||
|
||||
|
||||
@@ -72,9 +72,9 @@ impl Client {
|
||||
server_info: BytesMut,
|
||||
stats: Reporter,
|
||||
) -> Result<Client, Error> {
|
||||
let config = get_config();
|
||||
let config = get_config().clone();
|
||||
let transaction_mode = config.general.pool_mode.starts_with("t");
|
||||
drop(config);
|
||||
// drop(config);
|
||||
loop {
|
||||
trace!("Waiting for StartupMessage");
|
||||
|
||||
@@ -108,14 +108,51 @@ impl Client {
|
||||
// Regular startup message.
|
||||
PROTOCOL_VERSION_NUMBER => {
|
||||
trace!("Got StartupMessage");
|
||||
|
||||
// TODO: perform actual auth.
|
||||
let parameters = parse_startup(bytes.clone())?;
|
||||
|
||||
// Generate random backend ID and secret key
|
||||
let process_id: i32 = rand::random();
|
||||
let secret_key: i32 = rand::random();
|
||||
|
||||
// Perform MD5 authentication.
|
||||
// TODO: Add SASL support.
|
||||
let salt = md5_challenge(&mut stream).await?;
|
||||
|
||||
let code = match stream.read_u8().await {
|
||||
Ok(p) => p,
|
||||
Err(_) => return Err(Error::SocketError),
|
||||
};
|
||||
|
||||
// PasswordMessage
|
||||
if code as char != 'p' {
|
||||
debug!("Expected p, got {}", code as char);
|
||||
return Err(Error::ProtocolSyncError);
|
||||
}
|
||||
|
||||
let len = match stream.read_i32().await {
|
||||
Ok(len) => len,
|
||||
Err(_) => return Err(Error::SocketError),
|
||||
};
|
||||
|
||||
let mut password_response = vec![0u8; (len - 4) as usize];
|
||||
|
||||
match stream.read_exact(&mut password_response).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => return Err(Error::SocketError),
|
||||
};
|
||||
|
||||
// Compare server and client hashes.
|
||||
let password_hash =
|
||||
md5_hash_password(&config.user.name, &config.user.password, &salt);
|
||||
|
||||
if password_hash != password_response {
|
||||
debug!("Password authentication failed");
|
||||
wrong_password(&mut stream, &config.user.name).await?;
|
||||
return Err(Error::ClientError);
|
||||
}
|
||||
|
||||
debug!("Password authentication successful");
|
||||
|
||||
auth_ok(&mut stream).await?;
|
||||
write_all(&mut stream, server_info).await?;
|
||||
backend_key_data(&mut stream, process_id, secret_key).await?;
|
||||
|
||||
@@ -9,4 +9,5 @@ pub enum Error {
|
||||
ServerError,
|
||||
BadConfig,
|
||||
AllServersDown,
|
||||
ClientError,
|
||||
}
|
||||
|
||||
@@ -40,6 +40,26 @@ pub async fn auth_ok(stream: &mut TcpStream) -> Result<(), Error> {
|
||||
Ok(write_all(stream, auth_ok).await?)
|
||||
}
|
||||
|
||||
/// Generate md5 password challenge.
|
||||
pub async fn md5_challenge(stream: &mut TcpStream) -> Result<[u8; 4], Error> {
|
||||
// let mut rng = rand::thread_rng();
|
||||
let salt: [u8; 4] = [
|
||||
rand::random(),
|
||||
rand::random(),
|
||||
rand::random(),
|
||||
rand::random(),
|
||||
];
|
||||
|
||||
let mut res = BytesMut::new();
|
||||
res.put_u8(b'R');
|
||||
res.put_i32(12);
|
||||
res.put_i32(5); // MD5
|
||||
res.put_slice(&salt[..]);
|
||||
|
||||
write_all(stream, res).await?;
|
||||
Ok(salt)
|
||||
}
|
||||
|
||||
/// Give the client the process_id and secret we generated
|
||||
/// used in query cancellation.
|
||||
pub async fn backend_key_data(
|
||||
@@ -160,14 +180,8 @@ pub fn parse_startup(bytes: BytesMut) -> Result<HashMap<String, String>, Error>
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Send password challenge response to the server.
|
||||
/// This is the MD5 challenge.
|
||||
pub async fn md5_password(
|
||||
stream: &mut TcpStream,
|
||||
user: &str,
|
||||
password: &str,
|
||||
salt: &[u8],
|
||||
) -> Result<(), Error> {
|
||||
/// Create md5 password hash given a salt.
|
||||
pub fn md5_hash_password(user: &str, password: &str, salt: &[u8]) -> Vec<u8> {
|
||||
let mut md5 = Md5::new();
|
||||
|
||||
// First pass
|
||||
@@ -186,6 +200,19 @@ pub async fn md5_password(
|
||||
.collect::<Vec<u8>>();
|
||||
password.push(0);
|
||||
|
||||
password
|
||||
}
|
||||
|
||||
/// Send password challenge response to the server.
|
||||
/// This is the MD5 challenge.
|
||||
pub async fn md5_password(
|
||||
stream: &mut TcpStream,
|
||||
user: &str,
|
||||
password: &str,
|
||||
salt: &[u8],
|
||||
) -> Result<(), Error> {
|
||||
let password = md5_hash_password(user, password, salt);
|
||||
|
||||
let mut message = BytesMut::with_capacity(password.len() as usize + 5);
|
||||
|
||||
message.put_u8(b'p');
|
||||
@@ -264,6 +291,39 @@ pub async fn error_response(stream: &mut OwnedWriteHalf, message: &str) -> Resul
|
||||
Ok(write_all_half(stream, res).await?)
|
||||
}
|
||||
|
||||
pub async fn wrong_password(stream: &mut TcpStream, user: &str) -> Result<(), Error> {
|
||||
let mut error = BytesMut::new();
|
||||
|
||||
// Error level
|
||||
error.put_u8(b'S');
|
||||
error.put_slice(&b"FATAL\0"[..]);
|
||||
|
||||
// Error level (non-translatable)
|
||||
error.put_u8(b'V');
|
||||
error.put_slice(&b"FATAL\0"[..]);
|
||||
|
||||
// Error code: not sure how much this matters.
|
||||
error.put_u8(b'C');
|
||||
error.put_slice(&b"28P01\0"[..]); // system_error, see Appendix A.
|
||||
|
||||
// The short error message.
|
||||
error.put_u8(b'M');
|
||||
error.put_slice(&format!("password authentication failed for user \"{}\"\0", user).as_bytes());
|
||||
|
||||
// No more fields follow.
|
||||
error.put_u8(0);
|
||||
|
||||
// Compose the two message reply.
|
||||
let mut res = BytesMut::new();
|
||||
|
||||
res.put_u8(b'E');
|
||||
res.put_i32(error.len() as i32 + 4);
|
||||
|
||||
res.put(error);
|
||||
|
||||
write_all(stream, res).await
|
||||
}
|
||||
|
||||
/// Respond to a SHOW SHARD command.
|
||||
pub async fn show_response(
|
||||
stream: &mut OwnedWriteHalf,
|
||||
|
||||
125
src/scram.rs
125
src/scram.rs
@@ -1,5 +1,6 @@
|
||||
// SCRAM authentication...largely copy/pasted from
|
||||
// https://github.com/sfackler/rust-postgres/.
|
||||
// SCRAM-SHA-256 authentication. Heavily inspired by
|
||||
// https://github.com/sfackler/rust-postgres/
|
||||
// SASL implementation.
|
||||
|
||||
use bytes::BytesMut;
|
||||
use hmac::{Hmac, Mac};
|
||||
@@ -12,6 +13,8 @@ use std::fmt::Write;
|
||||
use crate::constants::*;
|
||||
use crate::errors::Error;
|
||||
|
||||
/// Normalize a password string. Postgres
|
||||
/// passwords don't have to be UTF-8.
|
||||
fn normalize(pass: &[u8]) -> Vec<u8> {
|
||||
let pass = match std::str::from_utf8(pass) {
|
||||
Ok(pass) => pass,
|
||||
@@ -24,6 +27,8 @@ fn normalize(pass: &[u8]) -> Vec<u8> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Keep the SASL state through the exchange.
|
||||
/// It takes 3 messages to complete the authentication.
|
||||
pub struct ScramSha256 {
|
||||
password: String,
|
||||
salted_password: [u8; 32],
|
||||
@@ -33,6 +38,8 @@ pub struct ScramSha256 {
|
||||
}
|
||||
|
||||
impl ScramSha256 {
|
||||
/// Create the Scram state from a password. It'll automatically
|
||||
/// generate a nonce.
|
||||
pub fn new(password: &str) -> ScramSha256 {
|
||||
let mut rng = rand::thread_rng();
|
||||
let nonce = (0..NONCE_LENGTH)
|
||||
@@ -48,6 +55,7 @@ impl ScramSha256 {
|
||||
Self::from_nonce(password, &nonce)
|
||||
}
|
||||
|
||||
/// Used for testing.
|
||||
pub fn from_nonce(password: &str, nonce: &str) -> ScramSha256 {
|
||||
let message = BytesMut::from(&format!("{}n=,r={}", "n,,", nonce).as_bytes()[..]);
|
||||
|
||||
@@ -60,15 +68,16 @@ impl ScramSha256 {
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the current state of the SASL authentication.
|
||||
pub fn message(&mut self) -> BytesMut {
|
||||
self.message.clone()
|
||||
}
|
||||
|
||||
/// Update the state with message received from server.
|
||||
pub fn update(&mut self, message: &BytesMut) -> Result<BytesMut, Error> {
|
||||
let server_message = Message::parse(message)?;
|
||||
|
||||
if !server_message.nonce.starts_with(&self.nonce) {
|
||||
// trace!("Bad server nonce");
|
||||
return Err(Error::ProtocolSyncError);
|
||||
}
|
||||
|
||||
@@ -82,28 +91,39 @@ impl ScramSha256 {
|
||||
&salt,
|
||||
server_message.iterations,
|
||||
);
|
||||
|
||||
// Save for verification of final server message.
|
||||
self.salted_password = salted_password;
|
||||
|
||||
let mut hmac = Hmac::<Sha256>::new_from_slice(&salted_password)
|
||||
.expect("HMAC is able to accept all key sizes");
|
||||
let mut hmac = match Hmac::<Sha256>::new_from_slice(&salted_password) {
|
||||
Ok(hmac) => hmac,
|
||||
Err(_) => return Err(Error::ServerError),
|
||||
};
|
||||
|
||||
hmac.update(b"Client Key");
|
||||
|
||||
let client_key = hmac.finalize().into_bytes();
|
||||
|
||||
let mut hash = Sha256::default();
|
||||
hash.update(client_key.as_slice());
|
||||
let stored_key = hash.finalize_fixed();
|
||||
|
||||
let stored_key = hash.finalize_fixed();
|
||||
let mut cbind_input = vec![];
|
||||
cbind_input.extend("n,,".as_bytes());
|
||||
|
||||
let cbind_input = base64::encode(&cbind_input);
|
||||
|
||||
self.message.clear();
|
||||
write!(
|
||||
|
||||
// Start writing the client reply.
|
||||
match write!(
|
||||
&mut self.message,
|
||||
"c={},r={}",
|
||||
cbind_input, server_message.nonce
|
||||
)
|
||||
.unwrap();
|
||||
) {
|
||||
Ok(_) => (),
|
||||
Err(_) => return Err(Error::ServerError),
|
||||
};
|
||||
|
||||
let auth_message = format!(
|
||||
"n=,r={},{},{}",
|
||||
@@ -112,23 +132,32 @@ impl ScramSha256 {
|
||||
String::from_utf8_lossy(&self.message[..])
|
||||
);
|
||||
|
||||
let mut hmac = Hmac::<Sha256>::new_from_slice(&stored_key)
|
||||
.expect("HMAC is able to accept all key sizes");
|
||||
let mut hmac = match Hmac::<Sha256>::new_from_slice(&stored_key) {
|
||||
Ok(hmac) => hmac,
|
||||
Err(_) => return Err(Error::ServerError),
|
||||
};
|
||||
hmac.update(auth_message.as_bytes());
|
||||
|
||||
// Save the auth message for server final message verification.
|
||||
self.auth_message = auth_message;
|
||||
|
||||
let client_signature = hmac.finalize().into_bytes();
|
||||
|
||||
// Sign the client proof.
|
||||
let mut client_proof = client_key;
|
||||
for (proof, signature) in client_proof.iter_mut().zip(client_signature) {
|
||||
*proof ^= signature;
|
||||
}
|
||||
|
||||
write!(&mut self.message, ",p={}", base64::encode(&*client_proof)).unwrap();
|
||||
|
||||
self.auth_message = auth_message;
|
||||
match write!(&mut self.message, ",p={}", base64::encode(&*client_proof)) {
|
||||
Ok(_) => (),
|
||||
Err(_) => return Err(Error::ServerError),
|
||||
};
|
||||
|
||||
Ok(self.message.clone())
|
||||
}
|
||||
|
||||
/// Verify final server message.
|
||||
pub fn finish(&mut self, message: &BytesMut) -> Result<(), Error> {
|
||||
let final_message = FinalMessage::parse(message)?;
|
||||
|
||||
@@ -137,13 +166,17 @@ impl ScramSha256 {
|
||||
Err(_) => return Err(Error::ProtocolSyncError),
|
||||
};
|
||||
|
||||
let mut hmac = Hmac::<Sha256>::new_from_slice(&self.salted_password)
|
||||
.expect("HMAC is able to accept all key sizes");
|
||||
let mut hmac = match Hmac::<Sha256>::new_from_slice(&self.salted_password) {
|
||||
Ok(hmac) => hmac,
|
||||
Err(_) => return Err(Error::ServerError),
|
||||
};
|
||||
hmac.update(b"Server Key");
|
||||
let server_key = hmac.finalize().into_bytes();
|
||||
|
||||
let mut hmac = Hmac::<Sha256>::new_from_slice(&server_key)
|
||||
.expect("HMAC is able to accept all key sizes");
|
||||
let mut hmac = match Hmac::<Sha256>::new_from_slice(&server_key) {
|
||||
Ok(hmac) => hmac,
|
||||
Err(_) => return Err(Error::ServerError),
|
||||
};
|
||||
hmac.update(self.auth_message.as_bytes());
|
||||
|
||||
match hmac.verify_slice(&verifier) {
|
||||
@@ -152,7 +185,7 @@ impl ScramSha256 {
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/sfackler/rust-postgres/blob/c3a029e60c1c0bd0be947049859b8fa5bd5ac220/postgres-protocol/src/authentication/sasl.rs#L35
|
||||
/// Hash the password with the salt i-times.
|
||||
fn hi(str: &[u8], salt: &[u8], i: u32) -> [u8; 32] {
|
||||
let mut hmac =
|
||||
Hmac::<Sha256>::new_from_slice(str).expect("HMAC is able to accept all key sizes");
|
||||
@@ -176,7 +209,7 @@ impl ScramSha256 {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
/// Parse the server challenge.
|
||||
struct Message {
|
||||
nonce: String,
|
||||
salt: String,
|
||||
@@ -184,47 +217,21 @@ struct Message {
|
||||
}
|
||||
|
||||
impl Message {
|
||||
/// Parse the server SASL challenge.
|
||||
fn parse(message: &BytesMut) -> Result<Message, Error> {
|
||||
if !message.starts_with(b"r=") {
|
||||
let parts = String::from_utf8_lossy(&message[..])
|
||||
.split(",")
|
||||
.map(|s| s.to_string())
|
||||
.collect::<Vec<String>>();
|
||||
|
||||
if parts.len() != 3 {
|
||||
return Err(Error::ProtocolSyncError);
|
||||
}
|
||||
|
||||
let mut i = 2;
|
||||
|
||||
while message[i] != b',' && i < message.len() {
|
||||
i += 1;
|
||||
}
|
||||
|
||||
let nonce = String::from_utf8_lossy(&message[2..i]).to_string();
|
||||
|
||||
// Skip the ,
|
||||
i += 1;
|
||||
|
||||
if !&message[i..].starts_with(b"s=") {
|
||||
return Err(Error::ProtocolSyncError);
|
||||
}
|
||||
|
||||
// Skip the s=
|
||||
i += 2;
|
||||
|
||||
let s = i;
|
||||
while message[i] != b',' && i < message.len() {
|
||||
i += 1;
|
||||
}
|
||||
|
||||
let salt = String::from_utf8_lossy(&message[s..i]).to_string();
|
||||
|
||||
// Skip the ,
|
||||
i += 1;
|
||||
|
||||
if !&message[i..].starts_with(b"i=") {
|
||||
return Err(Error::ProtocolSyncError);
|
||||
}
|
||||
|
||||
i += 2;
|
||||
|
||||
let iterations = match String::from_utf8_lossy(&message[i..]).parse::<u32>() {
|
||||
Ok(it) => it,
|
||||
let nonce = str::replace(&parts[0], "r=", "");
|
||||
let salt = str::replace(&parts[1], "s=", "");
|
||||
let iterations = match str::replace(&parts[2], "i=", "").parse::<u32>() {
|
||||
Ok(iterations) => iterations,
|
||||
Err(_) => return Err(Error::ProtocolSyncError),
|
||||
};
|
||||
|
||||
@@ -236,13 +243,15 @@ impl Message {
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse server final validation message.
|
||||
struct FinalMessage {
|
||||
value: String,
|
||||
}
|
||||
|
||||
impl FinalMessage {
|
||||
/// Parse the server final validation message.
|
||||
pub fn parse(message: &BytesMut) -> Result<FinalMessage, Error> {
|
||||
if !message.starts_with(b"v=") {
|
||||
if !message.starts_with(b"v=") || message.len() < 4 {
|
||||
return Err(Error::ProtocolSyncError);
|
||||
}
|
||||
|
||||
|
||||
@@ -137,6 +137,7 @@ impl Server {
|
||||
debug!("Starting SASL authentication");
|
||||
let sasl_len = (len - 8) as usize;
|
||||
let mut sasl_auth = vec![0u8; sasl_len];
|
||||
|
||||
match stream.read_exact(&mut sasl_auth).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => return Err(Error::SocketError),
|
||||
@@ -147,16 +148,22 @@ impl Server {
|
||||
if sasl_type == SCRAM_SHA_256 {
|
||||
debug!("Using {}", SCRAM_SHA_256);
|
||||
|
||||
// Send client message
|
||||
// Generate client message.
|
||||
let sasl_response = scram.message();
|
||||
|
||||
// SASLInitialResponse (F)
|
||||
let mut res = BytesMut::new();
|
||||
res.put_u8(b'p');
|
||||
|
||||
// length + String length + length + length of sasl response
|
||||
res.put_i32(
|
||||
4 + SCRAM_SHA_256.len() as i32
|
||||
+ 1
|
||||
+ sasl_response.len() as i32
|
||||
+ 4,
|
||||
4 // i32 size
|
||||
+ SCRAM_SHA_256.len() as i32 // length of SASL version string,
|
||||
+ 1 // Null terminator for the SASL version string,
|
||||
+ 4 // i32 size
|
||||
+ sasl_response.len() as i32, // length of SASL response
|
||||
);
|
||||
|
||||
res.put_slice(&format!("{}\0", SCRAM_SHA_256).as_bytes()[..]);
|
||||
res.put_i32(sasl_response.len() as i32);
|
||||
res.put(sasl_response);
|
||||
@@ -181,6 +188,7 @@ impl Server {
|
||||
let msg = BytesMut::from(&sasl_data[..]);
|
||||
let sasl_response = scram.update(&msg)?;
|
||||
|
||||
// SASLResponse
|
||||
let mut res = BytesMut::new();
|
||||
res.put_u8(b'p');
|
||||
res.put_i32(4 + sasl_response.len() as i32);
|
||||
|
||||
Reference in New Issue
Block a user