From 9c521f07c1ff05ee971346af7f8bfeff13ba3deb Mon Sep 17 00:00:00 2001 From: Lev Kokotov Date: Mon, 14 Feb 2022 05:11:53 -0800 Subject: [PATCH] parse startup client parameters (#16) --- src/client.rs | 9 ++++++++- src/messages.rs | 49 ++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 56 insertions(+), 2 deletions(-) diff --git a/src/client.rs b/src/client.rs index 154716c..43b3e63 100644 --- a/src/client.rs +++ b/src/client.rs @@ -8,6 +8,8 @@ use tokio::io::{AsyncReadExt, BufReader}; use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; use tokio::net::TcpStream; +use std::collections::HashMap; + use crate::config::Role; use crate::errors::Error; use crate::messages::*; @@ -52,6 +54,9 @@ pub struct Client { // Unless client specifies, route queries to the servers that have this role, // e.g. primary or replicas or any. default_server_role: Option, + + // Client parameters, e.g. user, client_encoding, etc. + parameters: HashMap, } impl Client { @@ -96,7 +101,7 @@ impl Client { // Regular startup message. 196608 => { // TODO: perform actual auth. - // TODO: record startup parameters client sends over. + let parameters = parse_startup(bytes.clone())?; // Generate random backend ID and secret key let process_id: i32 = rand::random(); @@ -121,6 +126,7 @@ impl Client { secret_key: secret_key, client_server_map: client_server_map, default_server_role: default_server_role, + parameters: parameters, }); } @@ -141,6 +147,7 @@ impl Client { secret_key: secret_key, client_server_map: client_server_map, default_server_role: default_server_role, + parameters: HashMap::new(), }); } diff --git a/src/messages.rs b/src/messages.rs index cd99de7..beb6505 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -1,9 +1,11 @@ -use bytes::{BufMut, BytesMut}; +use bytes::{Buf, BufMut, BytesMut}; use md5::{Digest, Md5}; use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader}; use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; use tokio::net::TcpStream; +use std::collections::HashMap; + use crate::errors::Error; // This is a funny one. `psql` parses this to figure out which @@ -105,6 +107,51 @@ pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Resu } } +/// Parse StartupMessage parameters. +/// e.g. user, database, application_name, etc. +pub fn parse_startup(mut bytes: BytesMut) -> Result, Error> { + let mut result = HashMap::new(); + let mut buf = Vec::new(); + let mut tmp = String::new(); + + while bytes.has_remaining() { + let mut c = bytes.get_u8(); + + // Null-terminated C-strings. + while c != 0 { + tmp.push(c as char); + c = bytes.get_u8(); + } + + if tmp.len() > 0 { + buf.push(tmp.clone()); + tmp.clear(); + } + } + + // Expect pairs of name and value + // and at least one pair to be present. + if buf.len() % 2 != 0 && buf.len() >= 2 { + return Err(Error::ClientBadStartup); + } + + let mut i = 0; + while i < buf.len() { + let name = buf[i].clone(); + let value = buf[i + 1].clone(); + let _ = result.insert(name, value); + i += 2; + } + + // Minimum required parameters + // I want to have the user at the very minimum, according to the protocol spec. + if !result.contains_key("user") { + return Err(Error::ClientBadStartup); + } + + Ok(result) +} + /// Send password challenge response to the server. /// This is the MD5 challenge. pub async fn md5_password(