Compare commits

..

9 Commits

Author SHA1 Message Date
Lev Kokotov
d865d9f9d8 readme 2022-06-20 06:20:12 -07:00
Lev Kokotov
d3310a62c2 Client md5 auth and clean up scram (#77)
* client md5 auth and clean up scram

* add pw

* add user

* add user

* log
2022-06-20 06:15:54 -07:00
Lev Kokotov
d412238f47 Implement SCRAM-SHA-256 for server authentication (PG14) (#76)
* Implement SCRAM-SHA-256

* test it

* trace

* move to community for auth

* hmm
2022-06-18 18:36:00 -07:00
dependabot[bot]
7782933f59 Bump regex from 1.5.4 to 1.5.5 (#75)
Bumps [regex](https://github.com/rust-lang/regex) from 1.5.4 to 1.5.5.
- [Release notes](https://github.com/rust-lang/regex/releases)
- [Changelog](https://github.com/rust-lang/regex/blob/master/CHANGELOG.md)
- [Commits](https://github.com/rust-lang/regex/compare/1.5.4...1.5.5)

---
updated-dependencies:
- dependency-name: regex
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>

Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2022-06-06 19:59:50 -07:00
Lev Kokotov
bac4e1f52c Only set application_name if it's different (#74)
* Only set application_name if it's different

* keep server named pgcat until something else changes
2022-06-05 09:48:06 -07:00
Lev Kokotov
37e3a86881 Pass application_name to server (#73)
* Pass application_name to server

* fmt
2022-06-03 00:15:50 -07:00
Lev Kokotov
61db13f614 Fix memory leak in client/server mapping (#71) 2022-05-18 16:24:03 -07:00
Lev Kokotov
fe32b5ef17 Reduce traffic on the stats channel (#69) 2022-05-17 13:05:25 -07:00
Lev Kokotov
54699222f8 Possible fix for clients waiting stat leak (#68) 2022-05-14 21:35:33 -07:00
14 changed files with 694 additions and 67 deletions

View File

@@ -12,13 +12,15 @@ jobs:
- image: cimg/rust:1.58.1
environment:
RUST_LOG: info
- image: cimg/postgres:14.0
auth:
username: mydockerhub-user
password: $DOCKERHUB_PASSWORD
- image: postgres:14
# auth:
# username: mydockerhub-user
# password: $DOCKERHUB_PASSWORD
environment:
POSTGRES_USER: postgres
POSTGRES_DB: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_HOST_AUTH_METHOD: scram-sha-256
# Add steps to the job
# See: https://circleci.com/docs/2.0/configuration-reference/#steps
steps:

View File

@@ -12,7 +12,8 @@ function start_pgcat() {
}
# Setup the database with shards and user
psql -e -h 127.0.0.1 -p 5432 -U postgres -f tests/sharding/query_routing_setup.sql
PGPASSWORD=postgres psql -e -h 127.0.0.1 -p 5432 -U postgres -f tests/sharding/query_routing_setup.sql
PGPASSWORD=sharding_user pgbench -h 127.0.0.1 -U sharding_user shard0 -i
PGPASSWORD=sharding_user pgbench -h 127.0.0.1 -U sharding_user shard1 -i
PGPASSWORD=sharding_user pgbench -h 127.0.0.1 -U sharding_user shard2 -i
@@ -30,26 +31,28 @@ toxiproxy-cli create -l 127.0.0.1:5433 -u 127.0.0.1:5432 postgres_replica
start_pgcat "info"
export PGPASSWORD=sharding_user
# pgbench test
pgbench -i -h 127.0.0.1 -p 6432
pgbench -h 127.0.0.1 -p 6432 -t 500 -c 2 --protocol simple -f tests/pgbench/simple.sql
pgbench -h 127.0.0.1 -p 6432 -t 500 -c 2 --protocol extended
pgbench -U sharding_user -i -h 127.0.0.1 -p 6432
pgbench -U sharding_user -h 127.0.0.1 -p 6432 -t 500 -c 2 --protocol simple -f tests/pgbench/simple.sql
pgbench -U sharding_user -h 127.0.0.1 -p 6432 -t 500 -c 2 --protocol extended
# COPY TO STDOUT test
psql -h 127.0.0.1 -p 6432 -c 'COPY (SELECT * FROM pgbench_accounts LIMIT 15) TO STDOUT;' > /dev/null
psql -U sharding_user -h 127.0.0.1 -p 6432 -c 'COPY (SELECT * FROM pgbench_accounts LIMIT 15) TO STDOUT;' > /dev/null
# Query cancellation test
(psql -h 127.0.0.1 -p 6432 -c 'SELECT pg_sleep(5)' || true) &
(psql -U sharding_user -h 127.0.0.1 -p 6432 -c 'SELECT pg_sleep(5)' || true) &
killall psql -s SIGINT
# Sharding insert
psql -e -h 127.0.0.1 -p 6432 -f tests/sharding/query_routing_test_insert.sql
psql -U sharding_user -e -h 127.0.0.1 -p 6432 -f tests/sharding/query_routing_test_insert.sql
# Sharding select
psql -e -h 127.0.0.1 -p 6432 -f tests/sharding/query_routing_test_select.sql > /dev/null
psql -U sharding_user -e -h 127.0.0.1 -p 6432 -f tests/sharding/query_routing_test_select.sql > /dev/null
# Replica/primary selection & more sharding tests
psql -e -h 127.0.0.1 -p 6432 -f tests/sharding/query_routing_test_primary_replica.sql > /dev/null
psql -U sharding_user -e -h 127.0.0.1 -p 6432 -f tests/sharding/query_routing_test_primary_replica.sql > /dev/null
#
# ActiveRecord tests
@@ -61,25 +64,25 @@ cd tests/ruby && \
cd ../..
# Admin tests
psql -e -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW STATS' > /dev/null
psql -h 127.0.0.1 -p 6432 -d pgbouncer -c 'RELOAD' > /dev/null
psql -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW CONFIG' > /dev/null
psql -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW DATABASES' > /dev/null
psql -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW LISTS' > /dev/null
psql -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW POOLS' > /dev/null
psql -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW VERSION' > /dev/null
psql -h 127.0.0.1 -p 6432 -d pgbouncer -c "SET client_encoding TO 'utf8'" > /dev/null # will ignore
(! psql -e -h 127.0.0.1 -p 6432 -d random_db -c 'SHOW STATS' > /dev/null)
psql -U sharding_user -e -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW STATS' > /dev/null
psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'RELOAD' > /dev/null
psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW CONFIG' > /dev/null
psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW DATABASES' > /dev/null
psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW LISTS' > /dev/null
psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW POOLS' > /dev/null
psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW VERSION' > /dev/null
psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c "SET client_encoding TO 'utf8'" > /dev/null # will ignore
(! psql -U sharding_user -e -h 127.0.0.1 -p 6432 -d random_db -c 'SHOW STATS' > /dev/null)
# Start PgCat in debug to demonstrate failover better
start_pgcat "debug"
start_pgcat "trace"
# Add latency to the replica at port 5433 slightly above the healthcheck timeout
toxiproxy-cli toxic add -t latency -a latency=300 postgres_replica
sleep 1
# Note the failover in the logs
timeout 5 psql -e -h 127.0.0.1 -p 6432 <<-EOF
timeout 5 psql -U sharding_user -e -h 127.0.0.1 -p 6432 <<-EOF
SELECT 1;
SELECT 1;
SELECT 1;
@@ -97,7 +100,7 @@ sed -i 's/pool_mode = "transaction"/pool_mode = "session"/' pgcat.toml
kill -SIGHUP $(pgrep pgcat)
# Prepared statements that will only work in session mode
pgbench -h 127.0.0.1 -p 6432 -t 500 -c 2 --protocol prepared
pgbench -U sharding_user -h 127.0.0.1 -p 6432 -t 500 -c 2 --protocol prepared
# Attempt clean shut down
killall pgcat -s SIGINT

93
Cargo.lock generated
View File

@@ -45,6 +45,12 @@ version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cdb031dd78e28731d87d56cc8ffef4a8f36ca26c38fe2de700543e627f8a464a"
[[package]]
name = "base64"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "904dfeac50f3cdaba28fc6f57fdcddb75f49ed61346676a78c4ffe55877802fd"
[[package]]
name = "bb8"
version = "0.7.1"
@@ -109,22 +115,23 @@ dependencies = [
[[package]]
name = "crypto-common"
version = "0.1.1"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "683d6b536309245c849479fba3da410962a43ed8e51c26b729208ec0ac2798d0"
checksum = "57952ca27b5e3606ff4dd79b0020231aaf9d6aa76dc05fd30137538c50bd3ce8"
dependencies = [
"generic-array",
"typenum",
]
[[package]]
name = "digest"
version = "0.10.1"
version = "0.10.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b697d66081d42af4fba142d56918a3cb21dc8eb63372c6b85d14f44fb9c5979b"
checksum = "f2fb860ca6fafa5552fb6d0e816a69c8e49f0908bf524e30a90d97c85892d506"
dependencies = [
"block-buffer",
"crypto-common",
"generic-array",
"subtle",
]
[[package]]
@@ -205,6 +212,15 @@ dependencies = [
"libc",
]
[[package]]
name = "hmac"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e"
dependencies = [
"digest",
]
[[package]]
name = "humantime"
version = "2.1.0"
@@ -352,14 +368,16 @@ dependencies = [
[[package]]
name = "pgcat"
version = "0.1.0-beta2"
version = "0.2.0-beta1"
dependencies = [
"arc-swap",
"async-trait",
"base64",
"bb8",
"bytes",
"chrono",
"env_logger",
"hmac",
"log",
"md-5",
"num_cpus",
@@ -370,7 +388,9 @@ dependencies = [
"serde",
"serde_derive",
"sha-1",
"sha2",
"sqlparser",
"stringprep",
"tokio",
"toml",
]
@@ -462,9 +482,9 @@ dependencies = [
[[package]]
name = "regex"
version = "1.5.4"
version = "1.5.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d07a8629359eb56f1e2fb1652bb04212c072a87ba68546a04065d525673ac461"
checksum = "1a11647b6b25ff05a515cb92c365cec08801e83423a235b51e231e1808747286"
dependencies = [
"aho-corasick",
"memchr",
@@ -511,6 +531,17 @@ dependencies = [
"digest",
]
[[package]]
name = "sha2"
version = "0.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "55deaec60f81eefe3cce0dc50bda92d6d8e88f2a27df7c5033b42afeb1ed2676"
dependencies = [
"cfg-if",
"cpufeatures",
"digest",
]
[[package]]
name = "signal-hook-registry"
version = "1.4.0"
@@ -541,6 +572,22 @@ dependencies = [
"log",
]
[[package]]
name = "stringprep"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ee348cb74b87454fff4b551cbf727025810a004f88aeacae7f85b87f4e9a1c1"
dependencies = [
"unicode-bidi",
"unicode-normalization",
]
[[package]]
name = "subtle"
version = "2.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6bdef32e8150c2a081110b42772ffe7d7c9032b606bc226c8260fd97e0976601"
[[package]]
name = "syn"
version = "1.0.86"
@@ -572,6 +619,21 @@ dependencies = [
"winapi",
]
[[package]]
name = "tinyvec"
version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87cc5ceb3875bb20c2890005a4e226a4651264a5c75edb2421b52861a0a0cb50"
dependencies = [
"tinyvec_macros",
]
[[package]]
name = "tinyvec_macros"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c"
[[package]]
name = "tokio"
version = "1.16.1"
@@ -617,6 +679,21 @@ version = "1.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dcf81ac59edc17cc8697ff311e8f5ef2d99fcbd9817b34cec66f90b6c3dfd987"
[[package]]
name = "unicode-bidi"
version = "0.3.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "099b7128301d285f79ddd55b9a83d5e6b9e97c92e0ea0daebee7263e932de992"
[[package]]
name = "unicode-normalization"
version = "0.1.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d54590932941a9e9266f0832deed84ebe1bf2e4c9e4a3554d393d18f5e854bf9"
dependencies = [
"tinyvec",
]
[[package]]
name = "unicode-xid"
version = "0.2.2"

View File

@@ -1,6 +1,6 @@
[package]
name = "pgcat"
version = "0.1.0-beta2"
version = "0.2.1-beta1"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
@@ -25,3 +25,7 @@ log = "0.4"
arc-swap = "1"
env_logger = "0.9"
parking_lot = "0.11"
hmac = "0.12"
sha2 = "0.10"
base64 = "0.13"
stringprep = "0.1"

View File

@@ -9,19 +9,19 @@ PostgreSQL pooler (like PgBouncer) with sharding, load balancing and failover su
**Beta**: looking for beta testers, see [#35](https://github.com/levkk/pgcat/issues/35).
## Features
| **Feature** | **Status** | **Comments** |
|--------------------------------|-----------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------|
| Transaction pooling | :white_check_mark: | Identical to PgBouncer. |
| Session pooling | :white_check_mark: | Identical to PgBouncer. |
| `COPY` support | :white_check_mark: | Both `COPY TO` and `COPY FROM` are supported. |
| Query cancellation | :white_check_mark: | Supported both in transaction and session pooling modes. |
| Load balancing of read queries | :white_check_mark: | Using round-robin between replicas. Primary is included when `primary_reads_enabled` is enabled (default). |
| Sharding | :white_check_mark: | Transactions are sharded using `SET SHARD TO` and `SET SHARDING KEY TO` syntax extensions; see examples below. |
| Failover | :white_check_mark: | Replicas are tested with a health check. If a health check fails, remaining replicas are attempted; see below for algorithm description and examples. |
| Statistics | :white_check_mark: | Statistics available in the admin database (`pgcat` and `pgbouncer`) with `SHOW STATS`, `SHOW POOLS` and others. |
| Live configuration reloading | :white_check_mark: | Reload supported settings with a `SIGHUP` to the process, e.g. `kill -s SIGHUP $(pgrep pgcat)` or `RELOAD` query issued to the admin database. |
| Client authentication | :x: :wrench: | On the roadmap; currently all clients are allowed to connect and one user is used to connect to Postgres. |
| Admin database | :white_check_mark: | The admin database, similar to PgBouncer's, allows to query for statistics and reload the configuration. |
| **Feature** | **Status** | **Comments** |
|--------------------------------|-----------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------|
| Transaction pooling | :white_check_mark: | Identical to PgBouncer. |
| Session pooling | :white_check_mark: | Identical to PgBouncer. |
| `COPY` support | :white_check_mark: | Both `COPY TO` and `COPY FROM` are supported. |
| Query cancellation | :white_check_mark: | Supported both in transaction and session pooling modes. |
| Load balancing of read queries | :white_check_mark: | Using round-robin between replicas. Primary is included when `primary_reads_enabled` is enabled (default). |
| Sharding | :white_check_mark: | Transactions are sharded using `SET SHARD TO` and `SET SHARDING KEY TO` syntax extensions; see examples below. |
| Failover | :white_check_mark: | Replicas are tested with a health check. If a health check fails, remaining replicas are attempted; see below for algorithm description and examples. |
| Statistics | :white_check_mark: | Statistics available in the admin database (`pgcat` and `pgbouncer`) with `SHOW STATS`, `SHOW POOLS` and others. |
| Live configuration reloading | :white_check_mark: | Reload supported settings with a `SIGHUP` to the process, e.g. `kill -s SIGHUP $(pgrep pgcat)` or `RELOAD` query issued to the admin database. |
| Client authentication | :white_check_mark: :wrench: | MD5 password authentication is supported, SCRAM is on the roadmap; one user is used to connect to Postgres with both SCRAM and MD5 supported. |
| Admin database | :white_check_mark: | The admin database, similar to PgBouncer's, allows to query for statistics and reload the configuration. |
## Deployment

View File

@@ -72,9 +72,9 @@ impl Client {
server_info: BytesMut,
stats: Reporter,
) -> Result<Client, Error> {
let config = get_config();
let config = get_config().clone();
let transaction_mode = config.general.pool_mode.starts_with("t");
drop(config);
// drop(config);
loop {
trace!("Waiting for StartupMessage");
@@ -108,14 +108,51 @@ impl Client {
// Regular startup message.
PROTOCOL_VERSION_NUMBER => {
trace!("Got StartupMessage");
// TODO: perform actual auth.
let parameters = parse_startup(bytes.clone())?;
// Generate random backend ID and secret key
let process_id: i32 = rand::random();
let secret_key: i32 = rand::random();
// Perform MD5 authentication.
// TODO: Add SASL support.
let salt = md5_challenge(&mut stream).await?;
let code = match stream.read_u8().await {
Ok(p) => p,
Err(_) => return Err(Error::SocketError),
};
// PasswordMessage
if code as char != 'p' {
debug!("Expected p, got {}", code as char);
return Err(Error::ProtocolSyncError);
}
let len = match stream.read_i32().await {
Ok(len) => len,
Err(_) => return Err(Error::SocketError),
};
let mut password_response = vec![0u8; (len - 4) as usize];
match stream.read_exact(&mut password_response).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError),
};
// Compare server and client hashes.
let password_hash =
md5_hash_password(&config.user.name, &config.user.password, &salt);
if password_hash != password_response {
debug!("Password authentication failed");
wrong_password(&mut stream, &config.user.name).await?;
return Err(Error::ClientError);
}
debug!("Password authentication successful");
auth_ok(&mut stream).await?;
write_all(&mut stream, server_info).await?;
backend_key_data(&mut stream, process_id, secret_key).await?;
@@ -342,6 +379,14 @@ impl Client {
server.address()
);
// Set application_name if any.
// TODO: investigate other parameters and set them too.
if self.parameters.contains_key("application_name") {
server
.set_name(&self.parameters["application_name"])
.await?;
}
// Transaction loop. Multiple queries can be issued by the client here.
// The connection belongs to the client until the transaction is over,
// or until the client disconnects if we are in session mode.
@@ -361,6 +406,7 @@ impl Client {
if server.in_transaction() {
server.query("ROLLBACK").await?;
server.query("DISCARD ALL").await?;
server.set_name("pgcat").await?;
}
return Err(err);
@@ -432,8 +478,11 @@ impl Client {
if server.in_transaction() {
server.query("ROLLBACK").await?;
server.query("DISCARD ALL").await?;
server.set_name("pgcat").await?;
}
self.release();
return Ok(());
}

View File

@@ -14,6 +14,13 @@ pub const CANCEL_REQUEST_CODE: i32 = 80877102;
// AuthenticationMD5Password
pub const MD5_ENCRYPTED_PASSWORD: i32 = 5;
// SASL
pub const SASL: i32 = 10;
pub const SASL_CONTINUE: i32 = 11;
pub const SASL_FINAL: i32 = 12;
pub const SCRAM_SHA_256: &str = "SCRAM-SHA-256";
pub const NONCE_LENGTH: usize = 24;
// AuthenticationOk
pub const AUTHENTICATION_SUCCESSFUL: i32 = 0;

View File

@@ -9,4 +9,5 @@ pub enum Error {
ServerError,
BadConfig,
AllServersDown,
ClientError,
}

View File

@@ -54,6 +54,7 @@ mod errors;
mod messages;
mod pool;
mod query_router;
mod scram;
mod server;
mod sharding;
mod stats;

View File

@@ -40,6 +40,26 @@ pub async fn auth_ok(stream: &mut TcpStream) -> Result<(), Error> {
Ok(write_all(stream, auth_ok).await?)
}
/// Generate md5 password challenge.
pub async fn md5_challenge(stream: &mut TcpStream) -> Result<[u8; 4], Error> {
// let mut rng = rand::thread_rng();
let salt: [u8; 4] = [
rand::random(),
rand::random(),
rand::random(),
rand::random(),
];
let mut res = BytesMut::new();
res.put_u8(b'R');
res.put_i32(12);
res.put_i32(5); // MD5
res.put_slice(&salt[..]);
write_all(stream, res).await?;
Ok(salt)
}
/// Give the client the process_id and secret we generated
/// used in query cancellation.
pub async fn backend_key_data(
@@ -160,14 +180,8 @@ pub fn parse_startup(bytes: BytesMut) -> Result<HashMap<String, String>, Error>
Ok(result)
}
/// Send password challenge response to the server.
/// This is the MD5 challenge.
pub async fn md5_password(
stream: &mut TcpStream,
user: &str,
password: &str,
salt: &[u8],
) -> Result<(), Error> {
/// Create md5 password hash given a salt.
pub fn md5_hash_password(user: &str, password: &str, salt: &[u8]) -> Vec<u8> {
let mut md5 = Md5::new();
// First pass
@@ -186,6 +200,19 @@ pub async fn md5_password(
.collect::<Vec<u8>>();
password.push(0);
password
}
/// Send password challenge response to the server.
/// This is the MD5 challenge.
pub async fn md5_password(
stream: &mut TcpStream,
user: &str,
password: &str,
salt: &[u8],
) -> Result<(), Error> {
let password = md5_hash_password(user, password, salt);
let mut message = BytesMut::with_capacity(password.len() as usize + 5);
message.put_u8(b'p');
@@ -264,6 +291,39 @@ pub async fn error_response(stream: &mut OwnedWriteHalf, message: &str) -> Resul
Ok(write_all_half(stream, res).await?)
}
pub async fn wrong_password(stream: &mut TcpStream, user: &str) -> Result<(), Error> {
let mut error = BytesMut::new();
// Error level
error.put_u8(b'S');
error.put_slice(&b"FATAL\0"[..]);
// Error level (non-translatable)
error.put_u8(b'V');
error.put_slice(&b"FATAL\0"[..]);
// Error code: not sure how much this matters.
error.put_u8(b'C');
error.put_slice(&b"28P01\0"[..]); // system_error, see Appendix A.
// The short error message.
error.put_u8(b'M');
error.put_slice(&format!("password authentication failed for user \"{}\"\0", user).as_bytes());
// No more fields follow.
error.put_u8(0);
// Compose the two message reply.
let mut res = BytesMut::new();
res.put_u8(b'E');
res.put_i32(error.len() as i32 + 4);
res.put(error);
write_all(stream, res).await
}
/// Respond to a SHOW SHARD command.
pub async fn show_response(
stream: &mut OwnedWriteHalf,

View File

@@ -209,8 +209,6 @@ impl ConnectionPool {
let index = self.round_robin % addresses.len();
let address = &addresses[index];
self.stats.client_waiting(process_id, address.id);
// Make sure you're getting a primary or a replica
// as per request. If no specific role is requested, the first
// available will be chosen.
@@ -224,6 +222,9 @@ impl ConnectionPool {
continue;
}
// Indicate we're waiting on a server connection from a pool.
self.stats.client_waiting(process_id, address.id);
// Check if we can connect
let mut conn = match self.databases[shard][index].get().await {
Ok(conn) => conn,

320
src/scram.rs Normal file
View File

@@ -0,0 +1,320 @@
// SCRAM-SHA-256 authentication. Heavily inspired by
// https://github.com/sfackler/rust-postgres/
// SASL implementation.
use bytes::BytesMut;
use hmac::{Hmac, Mac};
use rand::{self, Rng};
use sha2::digest::FixedOutput;
use sha2::{Digest, Sha256};
use std::fmt::Write;
use crate::constants::*;
use crate::errors::Error;
/// Normalize a password string. Postgres
/// passwords don't have to be UTF-8.
fn normalize(pass: &[u8]) -> Vec<u8> {
let pass = match std::str::from_utf8(pass) {
Ok(pass) => pass,
Err(_) => return pass.to_vec(),
};
match stringprep::saslprep(pass) {
Ok(pass) => pass.into_owned().into_bytes(),
Err(_) => pass.as_bytes().to_vec(),
}
}
/// Keep the SASL state through the exchange.
/// It takes 3 messages to complete the authentication.
pub struct ScramSha256 {
password: String,
salted_password: [u8; 32],
auth_message: String,
message: BytesMut,
nonce: String,
}
impl ScramSha256 {
/// Create the Scram state from a password. It'll automatically
/// generate a nonce.
pub fn new(password: &str) -> ScramSha256 {
let mut rng = rand::thread_rng();
let nonce = (0..NONCE_LENGTH)
.map(|_| {
let mut v = rng.gen_range(0x21u8..0x7e);
if v == 0x2c {
v = 0x7e
}
v as char
})
.collect::<String>();
Self::from_nonce(password, &nonce)
}
/// Used for testing.
pub fn from_nonce(password: &str, nonce: &str) -> ScramSha256 {
let message = BytesMut::from(&format!("{}n=,r={}", "n,,", nonce).as_bytes()[..]);
ScramSha256 {
password: password.to_string(),
nonce: String::from(nonce),
message,
salted_password: [0u8; 32],
auth_message: String::new(),
}
}
/// Get the current state of the SASL authentication.
pub fn message(&mut self) -> BytesMut {
self.message.clone()
}
/// Update the state with message received from server.
pub fn update(&mut self, message: &BytesMut) -> Result<BytesMut, Error> {
let server_message = Message::parse(message)?;
if !server_message.nonce.starts_with(&self.nonce) {
return Err(Error::ProtocolSyncError);
}
let salt = match base64::decode(&server_message.salt) {
Ok(salt) => salt,
Err(_) => return Err(Error::ProtocolSyncError),
};
let salted_password = Self::hi(
&normalize(&self.password.as_bytes()[..]),
&salt,
server_message.iterations,
);
// Save for verification of final server message.
self.salted_password = salted_password;
let mut hmac = match Hmac::<Sha256>::new_from_slice(&salted_password) {
Ok(hmac) => hmac,
Err(_) => return Err(Error::ServerError),
};
hmac.update(b"Client Key");
let client_key = hmac.finalize().into_bytes();
let mut hash = Sha256::default();
hash.update(client_key.as_slice());
let stored_key = hash.finalize_fixed();
let mut cbind_input = vec![];
cbind_input.extend("n,,".as_bytes());
let cbind_input = base64::encode(&cbind_input);
self.message.clear();
// Start writing the client reply.
match write!(
&mut self.message,
"c={},r={}",
cbind_input, server_message.nonce
) {
Ok(_) => (),
Err(_) => return Err(Error::ServerError),
};
let auth_message = format!(
"n=,r={},{},{}",
self.nonce,
String::from_utf8_lossy(&message[..]),
String::from_utf8_lossy(&self.message[..])
);
let mut hmac = match Hmac::<Sha256>::new_from_slice(&stored_key) {
Ok(hmac) => hmac,
Err(_) => return Err(Error::ServerError),
};
hmac.update(auth_message.as_bytes());
// Save the auth message for server final message verification.
self.auth_message = auth_message;
let client_signature = hmac.finalize().into_bytes();
// Sign the client proof.
let mut client_proof = client_key;
for (proof, signature) in client_proof.iter_mut().zip(client_signature) {
*proof ^= signature;
}
match write!(&mut self.message, ",p={}", base64::encode(&*client_proof)) {
Ok(_) => (),
Err(_) => return Err(Error::ServerError),
};
Ok(self.message.clone())
}
/// Verify final server message.
pub fn finish(&mut self, message: &BytesMut) -> Result<(), Error> {
let final_message = FinalMessage::parse(message)?;
let verifier = match base64::decode(&final_message.value) {
Ok(verifier) => verifier,
Err(_) => return Err(Error::ProtocolSyncError),
};
let mut hmac = match Hmac::<Sha256>::new_from_slice(&self.salted_password) {
Ok(hmac) => hmac,
Err(_) => return Err(Error::ServerError),
};
hmac.update(b"Server Key");
let server_key = hmac.finalize().into_bytes();
let mut hmac = match Hmac::<Sha256>::new_from_slice(&server_key) {
Ok(hmac) => hmac,
Err(_) => return Err(Error::ServerError),
};
hmac.update(self.auth_message.as_bytes());
match hmac.verify_slice(&verifier) {
Ok(_) => Ok(()),
Err(_) => return Err(Error::ServerError),
}
}
/// Hash the password with the salt i-times.
fn hi(str: &[u8], salt: &[u8], i: u32) -> [u8; 32] {
let mut hmac =
Hmac::<Sha256>::new_from_slice(str).expect("HMAC is able to accept all key sizes");
hmac.update(salt);
hmac.update(&[0, 0, 0, 1]);
let mut prev = hmac.finalize().into_bytes();
let mut hi = prev;
for _ in 1..i {
let mut hmac = Hmac::<Sha256>::new_from_slice(str).expect("already checked above");
hmac.update(&prev);
prev = hmac.finalize().into_bytes();
for (hi, prev) in hi.iter_mut().zip(prev) {
*hi ^= prev;
}
}
hi.into()
}
}
/// Parse the server challenge.
struct Message {
nonce: String,
salt: String,
iterations: u32,
}
impl Message {
/// Parse the server SASL challenge.
fn parse(message: &BytesMut) -> Result<Message, Error> {
let parts = String::from_utf8_lossy(&message[..])
.split(",")
.map(|s| s.to_string())
.collect::<Vec<String>>();
if parts.len() != 3 {
return Err(Error::ProtocolSyncError);
}
let nonce = str::replace(&parts[0], "r=", "");
let salt = str::replace(&parts[1], "s=", "");
let iterations = match str::replace(&parts[2], "i=", "").parse::<u32>() {
Ok(iterations) => iterations,
Err(_) => return Err(Error::ProtocolSyncError),
};
Ok(Message {
nonce,
salt,
iterations,
})
}
}
/// Parse server final validation message.
struct FinalMessage {
value: String,
}
impl FinalMessage {
/// Parse the server final validation message.
pub fn parse(message: &BytesMut) -> Result<FinalMessage, Error> {
if !message.starts_with(b"v=") || message.len() < 4 {
return Err(Error::ProtocolSyncError);
}
Ok(FinalMessage {
value: String::from_utf8_lossy(&message[2..]).to_string(),
})
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn parse_server_first_message() {
let message = BytesMut::from(
&"r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,s=QSXCR+Q6sek8bf92,i=4096".as_bytes()[..],
);
let message = Message::parse(&message).unwrap();
assert_eq!(message.nonce, "fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j");
assert_eq!(message.salt, "QSXCR+Q6sek8bf92");
assert_eq!(message.iterations, 4096);
}
#[test]
fn parse_server_last_message() {
let f = FinalMessage::parse(&BytesMut::from(
&"v=U+ppxD5XUKtradnv8e2MkeupiA8FU87Sg8CXzXHDAzw".as_bytes()[..],
))
.unwrap();
assert_eq!(
f.value,
"U+ppxD5XUKtradnv8e2MkeupiA8FU87Sg8CXzXHDAzw".to_string()
);
}
// recorded auth exchange from psql
#[test]
fn exchange() {
let password = "foobar";
let nonce = "9IZ2O01zb9IgiIZ1WJ/zgpJB";
let client_first = "n,,n=,r=9IZ2O01zb9IgiIZ1WJ/zgpJB";
let server_first =
"r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,s=fs3IXBy7U7+IvVjZ,i\
=4096";
let client_final =
"c=biws,r=9IZ2O01zb9IgiIZ1WJ/zgpJBjx/oIRLs02gGSHcw1KEty3eY,p=AmNKosjJzS3\
1NTlQYNs5BTeQjdHdk7lOflDo5re2an8=";
let server_final = "v=U+ppxD5XUKtradnv8e2MkeupiA8FU87Sg8CXzXHDAzw=";
let mut scram = ScramSha256::from_nonce(password, nonce);
let message = scram.message();
assert_eq!(std::str::from_utf8(&message).unwrap(), client_first);
let result = scram
.update(&BytesMut::from(&server_first.as_bytes()[..]))
.unwrap();
assert_eq!(std::str::from_utf8(&result).unwrap(), client_final);
scram
.finish(&BytesMut::from(&server_final.as_bytes()[..]))
.unwrap();
}
}

View File

@@ -12,6 +12,7 @@ use crate::config::{Address, User};
use crate::constants::*;
use crate::errors::Error;
use crate::messages::*;
use crate::scram::ScramSha256;
use crate::stats::Reporter;
use crate::ClientServerMap;
@@ -54,6 +55,9 @@ pub struct Server {
/// Reports various metrics, e.g. data sent & received.
stats: Reporter,
/// Application name using the server at the moment.
application_name: String,
}
impl Server {
@@ -86,6 +90,8 @@ impl Server {
// We'll be handling multiple packets, but they will all be structured the same.
// We'll loop here until this exchange is complete.
let mut scram = ScramSha256::new(&user.password);
loop {
let code = match stream.read_u8().await {
Ok(code) => code as char,
@@ -127,6 +133,91 @@ impl Server {
AUTHENTICATION_SUCCESSFUL => (),
SASL => {
debug!("Starting SASL authentication");
let sasl_len = (len - 8) as usize;
let mut sasl_auth = vec![0u8; sasl_len];
match stream.read_exact(&mut sasl_auth).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError),
};
let sasl_type = String::from_utf8_lossy(&sasl_auth[..sasl_len - 2]);
if sasl_type == SCRAM_SHA_256 {
debug!("Using {}", SCRAM_SHA_256);
// Generate client message.
let sasl_response = scram.message();
// SASLInitialResponse (F)
let mut res = BytesMut::new();
res.put_u8(b'p');
// length + String length + length + length of sasl response
res.put_i32(
4 // i32 size
+ SCRAM_SHA_256.len() as i32 // length of SASL version string,
+ 1 // Null terminator for the SASL version string,
+ 4 // i32 size
+ sasl_response.len() as i32, // length of SASL response
);
res.put_slice(&format!("{}\0", SCRAM_SHA_256).as_bytes()[..]);
res.put_i32(sasl_response.len() as i32);
res.put(sasl_response);
write_all(&mut stream, res).await?;
} else {
error!("Unsupported SCRAM version: {}", sasl_type);
return Err(Error::ServerError);
}
}
SASL_CONTINUE => {
trace!("Continuing SASL");
let mut sasl_data = vec![0u8; (len - 8) as usize];
match stream.read_exact(&mut sasl_data).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError),
};
let msg = BytesMut::from(&sasl_data[..]);
let sasl_response = scram.update(&msg)?;
// SASLResponse
let mut res = BytesMut::new();
res.put_u8(b'p');
res.put_i32(4 + sasl_response.len() as i32);
res.put(sasl_response);
write_all(&mut stream, res).await?;
}
SASL_FINAL => {
trace!("Final SASL");
let mut sasl_final = vec![0u8; len as usize - 8];
match stream.read_exact(&mut sasl_final).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError),
};
match scram.finish(&BytesMut::from(&sasl_final[..])) {
Ok(_) => {
debug!("SASL authentication successful");
}
Err(err) => {
debug!("SASL authentication failed");
return Err(err);
}
};
}
_ => {
error!("Unsupported authentication mechanism: {}", auth_code);
return Err(Error::ServerError);
@@ -210,7 +301,7 @@ impl Server {
let (read, write) = stream.into_split();
return Ok(Server {
let mut server = Server {
address: address.clone(),
read: BufReader::new(read),
write: write,
@@ -224,7 +315,12 @@ impl Server {
client_server_map: client_server_map,
connected_at: chrono::offset::Utc::now().naive_utc(),
stats: stats,
});
application_name: String::new(),
};
server.set_name("pgcat").await?;
return Ok(server);
}
// We have an unexpected message from the server during this exchange.
@@ -448,9 +544,14 @@ impl Server {
/// A shorthand for `SET application_name = $1`.
#[allow(dead_code)]
pub async fn set_name(&mut self, name: &str) -> Result<(), Error> {
Ok(self
.query(&format!("SET application_name = '{}'", name))
.await?)
if self.application_name != name {
self.application_name = name.to_string();
Ok(self
.query(&format!("SET application_name = '{}'", name))
.await?)
} else {
Ok(())
}
}
/// Get the servers address.

View File

@@ -16,6 +16,7 @@ ActiveRecord::Base.establish_connection(
username: 'sharding_user',
password: 'sharding_user',
database: 'rails_dev',
application_name: 'testing_pgcat',
prepared_statements: false, # Transaction mode
advisory_locks: false # Same
)
@@ -116,7 +117,7 @@ end
# Test evil clients
def poorly_behaved_client
conn = PG::connect("postgres://sharding_user:sharding_user@127.0.0.1:6432/rails_dev")
conn = PG::connect("postgres://sharding_user:sharding_user@127.0.0.1:6432/rails_dev?application_name=testing_pgcat")
conn.async_exec 'BEGIN'
conn.async_exec 'SELECT 1'