mirror of
https://github.com/postgresml/pgcat.git
synced 2026-03-23 01:16:30 +00:00
Compare commits
2 Commits
v1.0.0
...
levkk-fix-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d4d83a6fe7 | ||
|
|
5f745859c0 |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -1,2 +1,5 @@
|
||||
/target
|
||||
*.deb
|
||||
.idea/*
|
||||
tests/ruby/.bundle/*
|
||||
tests/ruby/vendor/*
|
||||
24
Cargo.lock
generated
24
Cargo.lock
generated
@@ -220,6 +220,12 @@ dependencies = [
|
||||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itoa"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1aab8fc367588b89dcee83ab0fd66b72b50b72fa1904d7095045ace2b0c81c35"
|
||||
|
||||
[[package]]
|
||||
name = "libc"
|
||||
version = "0.2.117"
|
||||
@@ -369,6 +375,7 @@ dependencies = [
|
||||
"regex",
|
||||
"serde",
|
||||
"serde_derive",
|
||||
"serde_json",
|
||||
"sha-1",
|
||||
"sqlparser",
|
||||
"statsd",
|
||||
@@ -478,6 +485,12 @@ version = "0.6.25"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b"
|
||||
|
||||
[[package]]
|
||||
name = "ryu"
|
||||
version = "1.0.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "73b4b750c782965c211b42f022f59af1fbceabdd026623714f104152f1ec149f"
|
||||
|
||||
[[package]]
|
||||
name = "scopeguard"
|
||||
version = "1.1.0"
|
||||
@@ -501,6 +514,17 @@ dependencies = [
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_json"
|
||||
version = "1.0.79"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8e8d9fa5c3b304765ce1fd9c4c8a3de2c8db365a5b91be52f186efc675681d95"
|
||||
dependencies = [
|
||||
"itoa",
|
||||
"ryu",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sha-1"
|
||||
version = "0.10.0"
|
||||
|
||||
@@ -17,6 +17,7 @@ sha-1 = "0.10"
|
||||
toml = "0.5"
|
||||
serde = "1"
|
||||
serde_derive = "1"
|
||||
serde_json = "1"
|
||||
regex = "1"
|
||||
num_cpus = "1"
|
||||
once_cell = "1"
|
||||
|
||||
12
pgcat.toml
12
pgcat.toml
@@ -48,8 +48,8 @@ password = "sharding_user"
|
||||
|
||||
# [ host, port, role ]
|
||||
servers = [
|
||||
[ "127.0.0.1", 5432, "primary" ],
|
||||
[ "localhost", 5432, "replica" ],
|
||||
["127.0.0.1", 5432, "primary"],
|
||||
["localhost", 5432, "replica"],
|
||||
# [ "127.0.1.1", 5432, "replica" ],
|
||||
]
|
||||
# Database name (e.g. "postgres")
|
||||
@@ -58,8 +58,8 @@ database = "shard0"
|
||||
[shards.1]
|
||||
# [ host, port, role ]
|
||||
servers = [
|
||||
[ "127.0.0.1", 5432, "primary" ],
|
||||
[ "localhost", 5432, "replica" ],
|
||||
["127.0.0.1", 5432, "primary"],
|
||||
["localhost", 5432, "replica"],
|
||||
# [ "127.0.1.1", 5432, "replica" ],
|
||||
]
|
||||
database = "shard1"
|
||||
@@ -67,8 +67,8 @@ database = "shard1"
|
||||
[shards.2]
|
||||
# [ host, port, role ]
|
||||
servers = [
|
||||
[ "127.0.0.1", 5432, "primary" ],
|
||||
[ "localhost", 5432, "replica" ],
|
||||
["127.0.0.1", 5432, "primary"],
|
||||
["localhost", 5432, "replica"],
|
||||
# [ "127.0.1.1", 5432, "replica" ],
|
||||
]
|
||||
database = "shard2"
|
||||
|
||||
@@ -108,15 +108,17 @@ impl Client {
|
||||
// Regular startup message.
|
||||
PROTOCOL_VERSION_NUMBER => {
|
||||
trace!("Got StartupMessage");
|
||||
|
||||
// TODO: perform actual auth.
|
||||
let parameters = parse_startup(bytes.clone())?;
|
||||
let user = match parameters.get(&String::from("user")) {
|
||||
Some(user) => user,
|
||||
None => return Err(Error::ClientBadStartup),
|
||||
};
|
||||
start_auth(&mut stream, user).await?;
|
||||
|
||||
// Generate random backend ID and secret key
|
||||
let process_id: i32 = rand::random();
|
||||
let secret_key: i32 = rand::random();
|
||||
|
||||
auth_ok(&mut stream).await?;
|
||||
write_all(&mut stream, server_info).await?;
|
||||
backend_key_data(&mut stream, process_id, secret_key).await?;
|
||||
ready_for_query(&mut stream).await?;
|
||||
|
||||
@@ -8,5 +8,7 @@ pub enum Error {
|
||||
// ServerTimeout,
|
||||
// DirtyServer,
|
||||
BadConfig,
|
||||
BadUserList,
|
||||
AllServersDown,
|
||||
AuthenticationError
|
||||
}
|
||||
|
||||
10
src/main.rs
10
src/main.rs
@@ -50,6 +50,7 @@ use std::sync::Arc;
|
||||
mod admin;
|
||||
mod client;
|
||||
mod config;
|
||||
mod userlist;
|
||||
mod constants;
|
||||
mod errors;
|
||||
mod messages;
|
||||
@@ -94,6 +95,15 @@ async fn main() {
|
||||
}
|
||||
};
|
||||
|
||||
// Prepare user list
|
||||
match userlist::parse("userlist.json").await {
|
||||
Ok(_) => (),
|
||||
Err(err) => {
|
||||
error!("Userlist parse error: {:?}", err);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let config = get_config();
|
||||
|
||||
let addr = format!("{}:{}", config.general.host, config.general.port);
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
/// Helper functions to send one-off protocol messages
|
||||
/// and handle TcpStream (TCP socket).
|
||||
|
||||
|
||||
use bytes::{Buf, BufMut, BytesMut};
|
||||
use md5::{Digest, Md5};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
|
||||
@@ -7,10 +9,16 @@ use tokio::net::{
|
||||
tcp::{OwnedReadHalf, OwnedWriteHalf},
|
||||
TcpStream,
|
||||
};
|
||||
use log::{error};
|
||||
|
||||
use crate::errors::Error;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use rand::Rng;
|
||||
|
||||
use crate::userlist::get_user_list;
|
||||
|
||||
|
||||
/// Postgres data type mappings
|
||||
/// used in RowDescription ('T') message.
|
||||
pub enum DataType {
|
||||
@@ -29,6 +37,95 @@ impl From<&DataType> for i32 {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
1. Generate salt (4 bytes of random data)
|
||||
md5(concat(md5(concat(password, username)), random-salt)))
|
||||
2. Send md5 auth request
|
||||
3. recieve PasswordMessage with salt.
|
||||
4. refactor md5_password function to be reusable
|
||||
5. check username hash combo against file
|
||||
6. AuthenticationOk or ErrorResponse
|
||||
**/
|
||||
pub async fn start_auth(stream: &mut TcpStream, user: &str) -> Result<(), Error> {
|
||||
//Generate random 4 byte salt
|
||||
let salt = rand::random::<u32>();
|
||||
|
||||
// Send AuthenticationMD5Password request
|
||||
send_md5_request(stream, salt).await?;
|
||||
|
||||
let code = match stream.read_u8().await {
|
||||
Ok(code) => code as char,
|
||||
Err(_) => return Err(Error::AuthenticationError),
|
||||
};
|
||||
|
||||
match code {
|
||||
// Password response
|
||||
'p' => {
|
||||
fetch_password_and_authenticate(stream, user, &salt).await?;
|
||||
Ok(auth_ok(stream).await?)
|
||||
}
|
||||
_ => {
|
||||
error!("Unknown code: {}", code);
|
||||
return Err(Error::AuthenticationError);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn send_md5_request(stream: &mut TcpStream, salt: u32) -> Result<(), Error> {
|
||||
let mut authentication_md5password = BytesMut::with_capacity(12);
|
||||
authentication_md5password.put_u8(b'R');
|
||||
authentication_md5password.put_i32(12);
|
||||
authentication_md5password.put_i32(5);
|
||||
authentication_md5password.put_u32(salt);
|
||||
|
||||
// Send AuthenticationMD5Password request
|
||||
Ok(write_all(stream, authentication_md5password).await?)
|
||||
}
|
||||
|
||||
pub async fn fetch_password_and_authenticate(stream: &mut TcpStream, user: &str, salt: &u32) -> Result<(), Error> {
|
||||
/**
|
||||
1. How do I store the lists of users and paswords? clear text or hash?? wtf
|
||||
2. Add auth to tests
|
||||
**/
|
||||
|
||||
let len = match stream.read_i32().await {
|
||||
Ok(len) => len,
|
||||
Err(_) => return Err(Error::AuthenticationError),
|
||||
};
|
||||
|
||||
// Read whatever is left.
|
||||
let mut password_hash = vec![0u8; len as usize - 4];
|
||||
|
||||
match stream.read_exact(&mut password_hash).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => return Err(Error::AuthenticationError),
|
||||
};
|
||||
|
||||
let user_list = get_user_list();
|
||||
let mut password = match user_list.get(user) {
|
||||
Some(p) => p,
|
||||
None => return Err(Error::AuthenticationError),
|
||||
};
|
||||
|
||||
let mut md5 = Md5::new();
|
||||
|
||||
// concat('md5', md5(concat(md5(concat(password, username)), random-salt)))
|
||||
// First pass
|
||||
md5.update(&password.as_bytes());
|
||||
md5.update(&user.as_bytes());
|
||||
let output = md5.finalize_reset();
|
||||
// Second pass
|
||||
md5.update(format!("{:x}", output));
|
||||
md5.update(salt.to_be_bytes().to_vec());
|
||||
|
||||
|
||||
let password_string: String = String::from_utf8(password_hash).expect("Could not get password hash");
|
||||
match format!("md5{:x}", md5.finalize()) == password_string {
|
||||
true => Ok(()),
|
||||
_ => Err(Error::AuthenticationError)
|
||||
}
|
||||
}
|
||||
|
||||
/// Tell the client that authentication handshake completed successfully.
|
||||
pub async fn auth_ok(stream: &mut TcpStream) -> Result<(), Error> {
|
||||
let mut auth_ok = BytesMut::with_capacity(9);
|
||||
|
||||
4
src/userlist.json
Normal file
4
src/userlist.json
Normal file
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"sven": "clear_text_password",
|
||||
"sharding_user": "sharding_user"
|
||||
}
|
||||
57
src/userlist.rs
Normal file
57
src/userlist.rs
Normal file
@@ -0,0 +1,57 @@
|
||||
use arc_swap::{ArcSwap, Guard};
|
||||
use log::{error};
|
||||
use once_cell::sync::Lazy;
|
||||
use tokio::fs::File;
|
||||
use tokio::io::AsyncReadExt;
|
||||
|
||||
use std::collections::{HashMap};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::errors::Error;
|
||||
|
||||
pub type UserList = HashMap<String, String>;
|
||||
static USER_LIST: Lazy<ArcSwap<UserList>> = Lazy::new(|| ArcSwap::from_pointee(HashMap::new()));
|
||||
|
||||
pub fn get_user_list() -> Guard<Arc<UserList>> {
|
||||
USER_LIST.load()
|
||||
}
|
||||
|
||||
/// Parse the user list.
|
||||
pub async fn parse(path: &str) -> Result<(), Error> {
|
||||
let mut contents = String::new();
|
||||
let mut file = match File::open(path).await {
|
||||
Ok(file) => file,
|
||||
Err(err) => {
|
||||
error!("Could not open '{}': {}", path, err.to_string());
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
};
|
||||
|
||||
match file.read_to_string(&mut contents).await {
|
||||
Ok(_) => (),
|
||||
Err(err) => {
|
||||
error!("Could not read config file: {}", err.to_string());
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
};
|
||||
|
||||
let map: HashMap<String, String> = serde_json::from_str(&contents).expect("JSON was not well-formatted");
|
||||
|
||||
|
||||
|
||||
USER_LIST.store(Arc::new(map.clone()));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_config() {
|
||||
parse("userlist.json").await.unwrap();
|
||||
assert_eq!(get_user_list()["sven"], "clear_text_password");
|
||||
assert_eq!(get_user_list()["sharding_user"], "sharding_user");
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user