diff --git a/src/admin.rs b/src/admin.rs index d4979fd..4460f98 100644 --- a/src/admin.rs +++ b/src/admin.rs @@ -37,7 +37,10 @@ where let code = query.get_u8() as char; if code != 'Q' { - return Err(Error::ProtocolSyncError); + return Err(Error::ProtocolSyncError(format!( + "Invalid code, expected 'Q' but got '{}'", + code + ))); } let len = query.get_i32() as usize; diff --git a/src/client.rs b/src/client.rs index 7bdb497..b55906b 100644 --- a/src/client.rs +++ b/src/client.rs @@ -189,7 +189,12 @@ pub async fn client_entrypoint( } // Client probably disconnected rejecting our plain text connection. - _ => Err(Error::ProtocolSyncError), + Ok((ClientConnectionType::Tls, _)) + | Ok((ClientConnectionType::CancelQuery, _)) => Err(Error::ProtocolSyncError( + format!("Bad postgres client (plain)"), + )), + + Err(err) => Err(err), } } } @@ -297,7 +302,10 @@ where // Something else, probably something is wrong and it's not our fault, // e.g. badly implemented Postgres client. - _ => Err(Error::ProtocolSyncError), + _ => Err(Error::ProtocolSyncError(format!( + "Unexpected startup code: {}", + code + ))), } } @@ -343,7 +351,11 @@ pub async fn startup_tls( } // Bad Postgres client. - _ => Err(Error::ProtocolSyncError), + Ok((ClientConnectionType::Tls, _)) | Ok((ClientConnectionType::CancelQuery, _)) => Err( + Error::ProtocolSyncError(format!("Bad postgres client (tls)")), + ), + + Err(err) => Err(err), } } @@ -373,7 +385,11 @@ where // This parameter is mandatory by the protocol. let username = match parameters.get("user") { Some(user) => user, - None => return Err(Error::ClientError), + None => { + return Err(Error::ClientError( + "Missing user parameter on client startup".to_string(), + )) + } }; let pool_name = match parameters.get("database") { @@ -416,25 +432,27 @@ where let code = match read.read_u8().await { Ok(p) => p, - Err(_) => return Err(Error::SocketError), + Err(_) => return Err(Error::SocketError(format!("Error reading password code from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name))), }; // PasswordMessage if code as char != 'p' { - debug!("Expected p, got {}", code as char); - return Err(Error::ProtocolSyncError); + return Err(Error::ProtocolSyncError(format!( + "Expected p, got {}", + code as char + ))); } let len = match read.read_i32().await { Ok(len) => len, - Err(_) => return Err(Error::SocketError), + Err(_) => return Err(Error::SocketError(format!("Error reading password message length from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name))), }; let mut password_response = vec![0u8; (len - 4) as usize]; match read.read_exact(&mut password_response).await { Ok(_) => (), - Err(_) => return Err(Error::SocketError), + Err(_) => return Err(Error::SocketError(format!("Error reading password message from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name))), }; // Authenticate admin user. @@ -451,7 +469,7 @@ where warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name); wrong_password(&mut write, username).await?; - return Err(Error::ClientError); + return Err(Error::ClientError(format!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name))); } (false, generate_server_info_for_admin()) @@ -470,8 +488,7 @@ where ) .await?; - warn!("Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name); - return Err(Error::ClientError); + return Err(Error::ClientError(format!("Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name))); } }; @@ -482,7 +499,7 @@ where warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name); wrong_password(&mut write, username).await?; - return Err(Error::ClientError); + return Err(Error::ClientError(format!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name))); } let transaction_mode = pool.settings.pool_mode == PoolMode::Transaction; @@ -669,8 +686,7 @@ where ) .await?; - warn!("Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", self.pool_name, self.username, self.application_name); - return Err(Error::ClientError); + return Err(Error::ClientError(format!("Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", self.pool_name, self.username, self.application_name))); } }; query_router.update_pool_settings(pool.settings.clone()); diff --git a/src/errors.rs b/src/errors.rs index 50301f3..7789a8a 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -3,13 +3,13 @@ /// Various errors. #[derive(Debug, PartialEq)] pub enum Error { - SocketError, + SocketError(String), ClientBadStartup, - ProtocolSyncError, + ProtocolSyncError(String), ServerError, BadConfig, AllServersDown, - ClientError, + ClientError(String), TlsError, StatementTimeout, ShuttingDown, diff --git a/src/messages.rs b/src/messages.rs index 0d7bc57..826508e 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -136,7 +136,11 @@ pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Resu match stream.write_all(&startup).await { Ok(_) => Ok(()), - Err(_) => Err(Error::SocketError), + Err(_) => { + return Err(Error::SocketError(format!( + "Error writing startup to server socket" + ))) + } } } @@ -450,7 +454,7 @@ where { match stream.write_all(&buf).await { Ok(_) => Ok(()), - Err(_) => Err(Error::SocketError), + Err(_) => return Err(Error::SocketError(format!("Error writing to socket"))), } } @@ -461,7 +465,7 @@ where { match stream.write_all(&buf).await { Ok(_) => Ok(()), - Err(_) => Err(Error::SocketError), + Err(_) => return Err(Error::SocketError(format!("Error writing to socket"))), } } @@ -472,19 +476,33 @@ where { let code = match stream.read_u8().await { Ok(code) => code, - Err(_) => return Err(Error::SocketError), + Err(_) => { + return Err(Error::SocketError(format!( + "Error reading message code from socket" + ))) + } }; let len = match stream.read_i32().await { Ok(len) => len, - Err(_) => return Err(Error::SocketError), + Err(_) => { + return Err(Error::SocketError(format!( + "Error reading message len from socket, code: {:?}", + code + ))) + } }; let mut buf = vec![0u8; len as usize - 4]; match stream.read_exact(&mut buf).await { Ok(_) => (), - Err(_) => return Err(Error::SocketError), + Err(_) => { + return Err(Error::SocketError(format!( + "Error reading message from socket, code: {:?}", + code + ))) + } }; let mut bytes = BytesMut::with_capacity(len as usize + 1); diff --git a/src/scram.rs b/src/scram.rs index 8c89f95..c3f920d 100644 --- a/src/scram.rs +++ b/src/scram.rs @@ -78,12 +78,12 @@ impl ScramSha256 { let server_message = Message::parse(message)?; if !server_message.nonce.starts_with(&self.nonce) { - return Err(Error::ProtocolSyncError); + return Err(Error::ProtocolSyncError(format!("SCRAM"))); } let salt = match base64::decode(&server_message.salt) { Ok(salt) => salt, - Err(_) => return Err(Error::ProtocolSyncError), + Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))), }; let salted_password = Self::hi( @@ -163,7 +163,7 @@ impl ScramSha256 { let verifier = match base64::decode(&final_message.value) { Ok(verifier) => verifier, - Err(_) => return Err(Error::ProtocolSyncError), + Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))), }; let mut hmac = match Hmac::::new_from_slice(&self.salted_password) { @@ -225,14 +225,14 @@ impl Message { .collect::>(); if parts.len() != 3 { - return Err(Error::ProtocolSyncError); + return Err(Error::ProtocolSyncError(format!("SCRAM"))); } let nonce = str::replace(&parts[0], "r=", ""); let salt = str::replace(&parts[1], "s=", ""); let iterations = match str::replace(&parts[2], "i=", "").parse::() { Ok(iterations) => iterations, - Err(_) => return Err(Error::ProtocolSyncError), + Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))), }; Ok(Message { @@ -252,7 +252,7 @@ impl FinalMessage { /// Parse the server final validation message. pub fn parse(message: &BytesMut) -> Result { if !message.starts_with(b"v=") || message.len() < 4 { - return Err(Error::ProtocolSyncError); + return Err(Error::ProtocolSyncError(format!("SCRAM"))); } Ok(FinalMessage { diff --git a/src/server.rs b/src/server.rs index 0d8f48d..65fb8d9 100644 --- a/src/server.rs +++ b/src/server.rs @@ -86,7 +86,10 @@ impl Server { Ok(stream) => stream, Err(err) => { error!("Could not connect to server: {}", err); - return Err(Error::SocketError); + return Err(Error::SocketError(format!( + "Could not connect to server: {}", + err + ))); } }; @@ -106,12 +109,12 @@ impl Server { loop { let code = match stream.read_u8().await { Ok(code) => code as char, - Err(_) => return Err(Error::SocketError), + Err(_) => return Err(Error::SocketError(format!("Error reading message code on server startup {{ username: {:?}, database: {:?} }}", user.username, database))), }; let len = match stream.read_i32().await { Ok(len) => len, - Err(_) => return Err(Error::SocketError), + Err(_) => return Err(Error::SocketError(format!("Error reading message len on server startup {{ username: {:?}, database: {:?} }}", user.username, database))), }; trace!("Message: {}", code); @@ -122,7 +125,7 @@ impl Server { // Determine which kind of authentication is required, if any. let auth_code = match stream.read_i32().await { Ok(auth_code) => auth_code, - Err(_) => return Err(Error::SocketError), + Err(_) => return Err(Error::SocketError(format!("Error reading auth code on server startup {{ username: {:?}, database: {:?} }}", user.username, database))), }; trace!("Auth: {}", auth_code); @@ -135,7 +138,7 @@ impl Server { match stream.read_exact(&mut salt).await { Ok(_) => (), - Err(_) => return Err(Error::SocketError), + Err(_) => return Err(Error::SocketError(format!("Error reading salt on server startup {{ username: {:?}, database: {:?} }}", user.username, database))), }; md5_password(&mut stream, &user.username, &user.password, &salt[..]) @@ -151,7 +154,7 @@ impl Server { match stream.read_exact(&mut sasl_auth).await { Ok(_) => (), - Err(_) => return Err(Error::SocketError), + Err(_) => return Err(Error::SocketError(format!("Error reading sasl message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))), }; let sasl_type = String::from_utf8_lossy(&sasl_auth[..sasl_len - 2]); @@ -193,7 +196,7 @@ impl Server { match stream.read_exact(&mut sasl_data).await { Ok(_) => (), - Err(_) => return Err(Error::SocketError), + Err(_) => return Err(Error::SocketError(format!("Error reading sasl cont message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))), }; let msg = BytesMut::from(&sasl_data[..]); @@ -214,7 +217,7 @@ impl Server { let mut sasl_final = vec![0u8; len as usize - 8]; match stream.read_exact(&mut sasl_final).await { Ok(_) => (), - Err(_) => return Err(Error::SocketError), + Err(_) => return Err(Error::SocketError(format!("Error reading sasl final message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))), }; match scram.finish(&BytesMut::from(&sasl_final[..])) { @@ -240,7 +243,7 @@ impl Server { 'E' => { let error_code = match stream.read_u8().await { Ok(error_code) => error_code, - Err(_) => return Err(Error::SocketError), + Err(_) => return Err(Error::SocketError(format!("Error reading error code message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))), }; trace!("Error: {}", error_code); @@ -256,7 +259,7 @@ impl Server { match stream.read_exact(&mut error).await { Ok(_) => (), - Err(_) => return Err(Error::SocketError), + Err(_) => return Err(Error::SocketError(format!("Error reading error message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))), }; // TODO: the error message contains multiple fields; we can decode them and @@ -275,7 +278,7 @@ impl Server { match stream.read_exact(&mut param).await { Ok(_) => (), - Err(_) => return Err(Error::SocketError), + Err(_) => return Err(Error::SocketError(format!("Error reading parameter status message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))), }; // Save the parameter so we can pass it to the client later. @@ -292,12 +295,12 @@ impl Server { // See: . process_id = match stream.read_i32().await { Ok(id) => id, - Err(_) => return Err(Error::SocketError), + Err(_) => return Err(Error::SocketError(format!("Error reading process id message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))), }; secret_key = match stream.read_i32().await { Ok(id) => id, - Err(_) => return Err(Error::SocketError), + Err(_) => return Err(Error::SocketError(format!("Error reading secret key message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))), }; } @@ -307,7 +310,7 @@ impl Server { match stream.read_exact(&mut idle).await { Ok(_) => (), - Err(_) => return Err(Error::SocketError), + Err(_) => return Err(Error::SocketError(format!("Error reading transaction status message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))), }; let (read, write) = stream.into_split(); @@ -341,7 +344,10 @@ impl Server { // Means we implemented the protocol wrong or we're not talking to a Postgres server. _ => { error!("Unknown code: {}", code); - return Err(Error::ProtocolSyncError); + return Err(Error::ProtocolSyncError(format!( + "Unknown server code: {}", + code + ))); } }; } @@ -359,7 +365,7 @@ impl Server { Ok(stream) => stream, Err(err) => { error!("Could not connect to server: {}", err); - return Err(Error::SocketError); + return Err(Error::SocketError(format!("Error reading cancel message"))); } }; @@ -438,7 +444,10 @@ impl Server { // Something totally unexpected, this is not a Postgres server we know. _ => { self.bad = true; - return Err(Error::ProtocolSyncError); + return Err(Error::ProtocolSyncError(format!( + "Unknown transaction state: {}", + transaction_state + ))); } };