diff --git a/.gitignore b/.gitignore index 50ccb25..3e695c0 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,5 @@ /target *.deb +.idea/* +tests/ruby/.bundle/* +tests/ruby/vendor/* \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index d6e42cb..01eee28 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -220,6 +220,12 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "itoa" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1aab8fc367588b89dcee83ab0fd66b72b50b72fa1904d7095045ace2b0c81c35" + [[package]] name = "libc" version = "0.2.117" @@ -369,6 +375,7 @@ dependencies = [ "regex", "serde", "serde_derive", + "serde_json", "sha-1", "sqlparser", "statsd", @@ -478,6 +485,12 @@ version = "0.6.25" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b" +[[package]] +name = "ryu" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73b4b750c782965c211b42f022f59af1fbceabdd026623714f104152f1ec149f" + [[package]] name = "scopeguard" version = "1.1.0" @@ -501,6 +514,17 @@ dependencies = [ "syn", ] +[[package]] +name = "serde_json" +version = "1.0.79" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e8d9fa5c3b304765ce1fd9c4c8a3de2c8db365a5b91be52f186efc675681d95" +dependencies = [ + "itoa", + "ryu", + "serde", +] + [[package]] name = "sha-1" version = "0.10.0" diff --git a/Cargo.toml b/Cargo.toml index d070c61..4131269 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ sha-1 = "0.10" toml = "0.5" serde = "1" serde_derive = "1" +serde_json = "1" regex = "1" num_cpus = "1" once_cell = "1" diff --git a/pgcat.toml b/pgcat.toml index 2131802..ef0184c 100644 --- a/pgcat.toml +++ b/pgcat.toml @@ -48,8 +48,8 @@ password = "sharding_user" # [ host, port, role ] servers = [ - [ "127.0.0.1", 5432, "primary" ], - [ "localhost", 5432, "replica" ], + ["127.0.0.1", 5432, "primary"], + ["localhost", 5432, "replica"], # [ "127.0.1.1", 5432, "replica" ], ] # Database name (e.g. "postgres") @@ -58,8 +58,8 @@ database = "shard0" [shards.1] # [ host, port, role ] servers = [ - [ "127.0.0.1", 5432, "primary" ], - [ "localhost", 5432, "replica" ], + ["127.0.0.1", 5432, "primary"], + ["localhost", 5432, "replica"], # [ "127.0.1.1", 5432, "replica" ], ] database = "shard1" @@ -67,8 +67,8 @@ database = "shard1" [shards.2] # [ host, port, role ] servers = [ - [ "127.0.0.1", 5432, "primary" ], - [ "localhost", 5432, "replica" ], + ["127.0.0.1", 5432, "primary"], + ["localhost", 5432, "replica"], # [ "127.0.1.1", 5432, "replica" ], ] database = "shard2" diff --git a/src/client.rs b/src/client.rs index f73ef22..2c382df 100644 --- a/src/client.rs +++ b/src/client.rs @@ -108,15 +108,18 @@ impl Client { // Regular startup message. PROTOCOL_VERSION_NUMBER => { trace!("Got StartupMessage"); - - // TODO: perform actual auth. let parameters = parse_startup(bytes.clone())?; + let mut user_name: String = String::new(); + match parameters.get(&"user") { + Some(&user) => user_name = user, + None => return Err(Error::ClientBadStartup), + } + start_auth(&mut stream, &user_name).await?; // Generate random backend ID and secret key let process_id: i32 = rand::random(); let secret_key: i32 = rand::random(); - auth_ok(&mut stream).await?; write_all(&mut stream, server_info).await?; backend_key_data(&mut stream, process_id, secret_key).await?; ready_for_query(&mut stream).await?; diff --git a/src/errors.rs b/src/errors.rs index 1fc26bb..88cd6bd 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -8,5 +8,7 @@ pub enum Error { // ServerTimeout, // DirtyServer, BadConfig, + BadUserList, AllServersDown, + AuthenticationError } diff --git a/src/main.rs b/src/main.rs index a41d5ec..abc49b4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -50,6 +50,7 @@ use std::sync::Arc; mod admin; mod client; mod config; +mod userlist; mod constants; mod errors; mod messages; @@ -94,6 +95,15 @@ async fn main() { } }; + // Prepare user list + match userlist::parse("userlist.json").await { + Ok(_) => (), + Err(err) => { + error!("Userlist parse error: {:?}", err); + return; + } + }; + let config = get_config(); let addr = format!("{}:{}", config.general.host, config.general.port); diff --git a/src/messages.rs b/src/messages.rs index 473c8de..3a1e0bb 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -1,5 +1,7 @@ /// Helper functions to send one-off protocol messages /// and handle TcpStream (TCP socket). + + use bytes::{Buf, BufMut, BytesMut}; use md5::{Digest, Md5}; use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader}; @@ -7,10 +9,16 @@ use tokio::net::{ tcp::{OwnedReadHalf, OwnedWriteHalf}, TcpStream, }; +use log::{error}; use crate::errors::Error; use std::collections::HashMap; +use rand::Rng; + +use crate::userlist::get_user_list; + + /// Postgres data type mappings /// used in RowDescription ('T') message. pub enum DataType { @@ -29,6 +37,98 @@ impl From<&DataType> for i32 { } } +/** +1. Generate salt (4 bytes of random data) +md5(concat(md5(concat(password, username)), random-salt))) +2. Send md5 auth request +3. recieve PasswordMessage with salt. +4. refactor md5_password function to be reusable +5. check username hash combo against file +6. AuthenticationOk or ErrorResponse + **/ +pub async fn start_auth(stream: &mut TcpStream, user_name: &String) -> Result<(), Error> { + let mut rng = rand::thread_rng(); + + //Generate random 4 byte salt + let salt = rng.gen::(); + + // Send AuthenticationMD5Password request + send_md5_request(stream, salt).await?; + + let code = match stream.read_u8().await { + Ok(code) => code as char, + Err(_) => return Err(Error::AuthenticationError), + }; + + match code { + // Password response + 'p' => { + fetch_password_and_authenticate(stream, &user_name, &salt).await?; + Ok(auth_ok(stream).await?) + } + _ => { + error!("Unknown code: {}", code); + return Err(Error::AuthenticationError); + } + } +} + +pub async fn send_md5_request(stream: &mut TcpStream, salt: u32) -> Result<(), Error> { + let mut authentication_md5password = BytesMut::with_capacity(12); + authentication_md5password.put_u8(b'R'); + authentication_md5password.put_i32(12); + authentication_md5password.put_i32(5); + authentication_md5password.put_u32(salt); + + // Send AuthenticationMD5Password request + Ok(write_all(stream, authentication_md5password).await?) +} + +pub async fn fetch_password_and_authenticate(stream: &mut TcpStream, user_name: &String, salt: &u32) -> Result<(), Error> { + /** + 1. How do I store the lists of users and paswords? clear text or hash?? wtf + 2. Add auth to tests + **/ + + let len = match stream.read_i32().await { + Ok(len) => len, + Err(_) => return Err(Error::AuthenticationError), + }; + + // Read whatever is left. + let mut password_hash = vec![0u8; len as usize - 4]; + + match stream.read_exact(&mut password_hash).await { + Ok(_) => (), + Err(_) => return Err(Error::AuthenticationError), + }; + + let user_list = get_user_list(); + let mut password: String = String::new(); + match user_list.get(&user_name) { + Some(&p) => password = p, + None => return Err(Error::AuthenticationError), + } + + let mut md5 = Md5::new(); + + // concat('md5', md5(concat(md5(concat(password, username)), random-salt))) + // First pass + md5.update(&password.as_bytes()); + md5.update(&user_name.as_bytes()); + let output = md5.finalize_reset(); + // Second pass + md5.update(format!("{:x}", output)); + md5.update(salt.to_be_bytes().to_vec()); + + + let password_string: String = String::from_utf8(password_hash).expect("Could not get password hash"); + match format!("md5{:x}", md5.finalize()) == password_string { + true => Ok(()), + _ => Err(Error::AuthenticationError) + } +} + /// Tell the client that authentication handshake completed successfully. pub async fn auth_ok(stream: &mut TcpStream) -> Result<(), Error> { let mut auth_ok = BytesMut::with_capacity(9); diff --git a/src/userlist.json b/src/userlist.json new file mode 100644 index 0000000..2b4c8a2 --- /dev/null +++ b/src/userlist.json @@ -0,0 +1,4 @@ +{ + "sven": "clear_text_password", + "sharding_user": "sharding_user" +} \ No newline at end of file diff --git a/src/userlist.rs b/src/userlist.rs new file mode 100644 index 0000000..167d917 --- /dev/null +++ b/src/userlist.rs @@ -0,0 +1,57 @@ +use arc_swap::{ArcSwap, Guard}; +use log::{error}; +use once_cell::sync::Lazy; +use tokio::fs::File; +use tokio::io::AsyncReadExt; + +use std::collections::{HashMap}; +use std::sync::Arc; + +use crate::errors::Error; + +pub type UserList = HashMap; +static USER_LIST: Lazy> = Lazy::new(|| ArcSwap::from_pointee(HashMap::new())); + +pub fn get_user_list() -> Guard> { + USER_LIST.load() +} + +/// Parse the user list. +pub async fn parse(path: &str) -> Result<(), Error> { + let mut contents = String::new(); + let mut file = match File::open(path).await { + Ok(file) => file, + Err(err) => { + error!("Could not open '{}': {}", path, err.to_string()); + return Err(Error::BadConfig); + } + }; + + match file.read_to_string(&mut contents).await { + Ok(_) => (), + Err(err) => { + error!("Could not read config file: {}", err.to_string()); + return Err(Error::BadConfig); + } + }; + + let map: HashMap = serde_json::from_str(&contents).expect("JSON was not well-formatted"); + + + + USER_LIST.store(Arc::new(map.clone())); + + Ok(()) +} + +#[cfg(test)] +mod test { + use super::*; + + #[tokio::test] + async fn test_config() { + parse("userlist.json").await.unwrap(); + assert_eq!(get_user_list()["sven"], "clear_text_password"); + assert_eq!(get_user_list()["sharding_user"], "sharding_user"); + } +}