Adds details to errors and fixes error propagation bug (#239)

This commit is contained in:
zainkabani
2022-11-17 09:24:39 -08:00
committed by GitHub
parent fcd2cae4e1
commit c62b86f4e6
6 changed files with 94 additions and 48 deletions

View File

@@ -37,7 +37,10 @@ where
let code = query.get_u8() as char; let code = query.get_u8() as char;
if code != 'Q' { if code != 'Q' {
return Err(Error::ProtocolSyncError); return Err(Error::ProtocolSyncError(format!(
"Invalid code, expected 'Q' but got '{}'",
code
)));
} }
let len = query.get_i32() as usize; let len = query.get_i32() as usize;

View File

@@ -189,7 +189,12 @@ pub async fn client_entrypoint(
} }
// Client probably disconnected rejecting our plain text connection. // Client probably disconnected rejecting our plain text connection.
_ => Err(Error::ProtocolSyncError), Ok((ClientConnectionType::Tls, _))
| Ok((ClientConnectionType::CancelQuery, _)) => Err(Error::ProtocolSyncError(
format!("Bad postgres client (plain)"),
)),
Err(err) => Err(err),
} }
} }
} }
@@ -297,7 +302,10 @@ where
// Something else, probably something is wrong and it's not our fault, // Something else, probably something is wrong and it's not our fault,
// e.g. badly implemented Postgres client. // e.g. badly implemented Postgres client.
_ => Err(Error::ProtocolSyncError), _ => Err(Error::ProtocolSyncError(format!(
"Unexpected startup code: {}",
code
))),
} }
} }
@@ -343,7 +351,11 @@ pub async fn startup_tls(
} }
// Bad Postgres client. // Bad Postgres client.
_ => Err(Error::ProtocolSyncError), Ok((ClientConnectionType::Tls, _)) | Ok((ClientConnectionType::CancelQuery, _)) => Err(
Error::ProtocolSyncError(format!("Bad postgres client (tls)")),
),
Err(err) => Err(err),
} }
} }
@@ -373,7 +385,11 @@ where
// This parameter is mandatory by the protocol. // This parameter is mandatory by the protocol.
let username = match parameters.get("user") { let username = match parameters.get("user") {
Some(user) => user, Some(user) => user,
None => return Err(Error::ClientError), None => {
return Err(Error::ClientError(
"Missing user parameter on client startup".to_string(),
))
}
}; };
let pool_name = match parameters.get("database") { let pool_name = match parameters.get("database") {
@@ -416,25 +432,27 @@ where
let code = match read.read_u8().await { let code = match read.read_u8().await {
Ok(p) => p, Ok(p) => p,
Err(_) => return Err(Error::SocketError), Err(_) => return Err(Error::SocketError(format!("Error reading password code from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name))),
}; };
// PasswordMessage // PasswordMessage
if code as char != 'p' { if code as char != 'p' {
debug!("Expected p, got {}", code as char); return Err(Error::ProtocolSyncError(format!(
return Err(Error::ProtocolSyncError); "Expected p, got {}",
code as char
)));
} }
let len = match read.read_i32().await { let len = match read.read_i32().await {
Ok(len) => len, Ok(len) => len,
Err(_) => return Err(Error::SocketError), Err(_) => return Err(Error::SocketError(format!("Error reading password message length from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name))),
}; };
let mut password_response = vec![0u8; (len - 4) as usize]; let mut password_response = vec![0u8; (len - 4) as usize];
match read.read_exact(&mut password_response).await { match read.read_exact(&mut password_response).await {
Ok(_) => (), Ok(_) => (),
Err(_) => return Err(Error::SocketError), Err(_) => return Err(Error::SocketError(format!("Error reading password message from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name))),
}; };
// Authenticate admin user. // Authenticate admin user.
@@ -451,7 +469,7 @@ where
warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name); warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name);
wrong_password(&mut write, username).await?; wrong_password(&mut write, username).await?;
return Err(Error::ClientError); return Err(Error::ClientError(format!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name)));
} }
(false, generate_server_info_for_admin()) (false, generate_server_info_for_admin())
@@ -470,8 +488,7 @@ where
) )
.await?; .await?;
warn!("Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name); return Err(Error::ClientError(format!("Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name)));
return Err(Error::ClientError);
} }
}; };
@@ -482,7 +499,7 @@ where
warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name); warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name);
wrong_password(&mut write, username).await?; wrong_password(&mut write, username).await?;
return Err(Error::ClientError); return Err(Error::ClientError(format!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name)));
} }
let transaction_mode = pool.settings.pool_mode == PoolMode::Transaction; let transaction_mode = pool.settings.pool_mode == PoolMode::Transaction;
@@ -669,8 +686,7 @@ where
) )
.await?; .await?;
warn!("Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", self.pool_name, self.username, self.application_name); return Err(Error::ClientError(format!("Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", self.pool_name, self.username, self.application_name)));
return Err(Error::ClientError);
} }
}; };
query_router.update_pool_settings(pool.settings.clone()); query_router.update_pool_settings(pool.settings.clone());

View File

@@ -3,13 +3,13 @@
/// Various errors. /// Various errors.
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq)]
pub enum Error { pub enum Error {
SocketError, SocketError(String),
ClientBadStartup, ClientBadStartup,
ProtocolSyncError, ProtocolSyncError(String),
ServerError, ServerError,
BadConfig, BadConfig,
AllServersDown, AllServersDown,
ClientError, ClientError(String),
TlsError, TlsError,
StatementTimeout, StatementTimeout,
ShuttingDown, ShuttingDown,

View File

@@ -136,7 +136,11 @@ pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Resu
match stream.write_all(&startup).await { match stream.write_all(&startup).await {
Ok(_) => Ok(()), Ok(_) => Ok(()),
Err(_) => Err(Error::SocketError), Err(_) => {
return Err(Error::SocketError(format!(
"Error writing startup to server socket"
)))
}
} }
} }
@@ -450,7 +454,7 @@ where
{ {
match stream.write_all(&buf).await { match stream.write_all(&buf).await {
Ok(_) => Ok(()), Ok(_) => Ok(()),
Err(_) => Err(Error::SocketError), Err(_) => return Err(Error::SocketError(format!("Error writing to socket"))),
} }
} }
@@ -461,7 +465,7 @@ where
{ {
match stream.write_all(&buf).await { match stream.write_all(&buf).await {
Ok(_) => Ok(()), Ok(_) => Ok(()),
Err(_) => Err(Error::SocketError), Err(_) => return Err(Error::SocketError(format!("Error writing to socket"))),
} }
} }
@@ -472,19 +476,33 @@ where
{ {
let code = match stream.read_u8().await { let code = match stream.read_u8().await {
Ok(code) => code, Ok(code) => code,
Err(_) => return Err(Error::SocketError), Err(_) => {
return Err(Error::SocketError(format!(
"Error reading message code from socket"
)))
}
}; };
let len = match stream.read_i32().await { let len = match stream.read_i32().await {
Ok(len) => len, Ok(len) => len,
Err(_) => return Err(Error::SocketError), Err(_) => {
return Err(Error::SocketError(format!(
"Error reading message len from socket, code: {:?}",
code
)))
}
}; };
let mut buf = vec![0u8; len as usize - 4]; let mut buf = vec![0u8; len as usize - 4];
match stream.read_exact(&mut buf).await { match stream.read_exact(&mut buf).await {
Ok(_) => (), Ok(_) => (),
Err(_) => return Err(Error::SocketError), Err(_) => {
return Err(Error::SocketError(format!(
"Error reading message from socket, code: {:?}",
code
)))
}
}; };
let mut bytes = BytesMut::with_capacity(len as usize + 1); let mut bytes = BytesMut::with_capacity(len as usize + 1);

View File

@@ -78,12 +78,12 @@ impl ScramSha256 {
let server_message = Message::parse(message)?; let server_message = Message::parse(message)?;
if !server_message.nonce.starts_with(&self.nonce) { if !server_message.nonce.starts_with(&self.nonce) {
return Err(Error::ProtocolSyncError); return Err(Error::ProtocolSyncError(format!("SCRAM")));
} }
let salt = match base64::decode(&server_message.salt) { let salt = match base64::decode(&server_message.salt) {
Ok(salt) => salt, Ok(salt) => salt,
Err(_) => return Err(Error::ProtocolSyncError), Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))),
}; };
let salted_password = Self::hi( let salted_password = Self::hi(
@@ -163,7 +163,7 @@ impl ScramSha256 {
let verifier = match base64::decode(&final_message.value) { let verifier = match base64::decode(&final_message.value) {
Ok(verifier) => verifier, Ok(verifier) => verifier,
Err(_) => return Err(Error::ProtocolSyncError), Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))),
}; };
let mut hmac = match Hmac::<Sha256>::new_from_slice(&self.salted_password) { let mut hmac = match Hmac::<Sha256>::new_from_slice(&self.salted_password) {
@@ -225,14 +225,14 @@ impl Message {
.collect::<Vec<String>>(); .collect::<Vec<String>>();
if parts.len() != 3 { if parts.len() != 3 {
return Err(Error::ProtocolSyncError); return Err(Error::ProtocolSyncError(format!("SCRAM")));
} }
let nonce = str::replace(&parts[0], "r=", ""); let nonce = str::replace(&parts[0], "r=", "");
let salt = str::replace(&parts[1], "s=", ""); let salt = str::replace(&parts[1], "s=", "");
let iterations = match str::replace(&parts[2], "i=", "").parse::<u32>() { let iterations = match str::replace(&parts[2], "i=", "").parse::<u32>() {
Ok(iterations) => iterations, Ok(iterations) => iterations,
Err(_) => return Err(Error::ProtocolSyncError), Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))),
}; };
Ok(Message { Ok(Message {
@@ -252,7 +252,7 @@ impl FinalMessage {
/// Parse the server final validation message. /// Parse the server final validation message.
pub fn parse(message: &BytesMut) -> Result<FinalMessage, Error> { pub fn parse(message: &BytesMut) -> Result<FinalMessage, Error> {
if !message.starts_with(b"v=") || message.len() < 4 { if !message.starts_with(b"v=") || message.len() < 4 {
return Err(Error::ProtocolSyncError); return Err(Error::ProtocolSyncError(format!("SCRAM")));
} }
Ok(FinalMessage { Ok(FinalMessage {

View File

@@ -86,7 +86,10 @@ impl Server {
Ok(stream) => stream, Ok(stream) => stream,
Err(err) => { Err(err) => {
error!("Could not connect to server: {}", err); error!("Could not connect to server: {}", err);
return Err(Error::SocketError); return Err(Error::SocketError(format!(
"Could not connect to server: {}",
err
)));
} }
}; };
@@ -106,12 +109,12 @@ impl Server {
loop { loop {
let code = match stream.read_u8().await { let code = match stream.read_u8().await {
Ok(code) => code as char, Ok(code) => code as char,
Err(_) => return Err(Error::SocketError), Err(_) => return Err(Error::SocketError(format!("Error reading message code on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
}; };
let len = match stream.read_i32().await { let len = match stream.read_i32().await {
Ok(len) => len, Ok(len) => len,
Err(_) => return Err(Error::SocketError), Err(_) => return Err(Error::SocketError(format!("Error reading message len on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
}; };
trace!("Message: {}", code); trace!("Message: {}", code);
@@ -122,7 +125,7 @@ impl Server {
// Determine which kind of authentication is required, if any. // Determine which kind of authentication is required, if any.
let auth_code = match stream.read_i32().await { let auth_code = match stream.read_i32().await {
Ok(auth_code) => auth_code, Ok(auth_code) => auth_code,
Err(_) => return Err(Error::SocketError), Err(_) => return Err(Error::SocketError(format!("Error reading auth code on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
}; };
trace!("Auth: {}", auth_code); trace!("Auth: {}", auth_code);
@@ -135,7 +138,7 @@ impl Server {
match stream.read_exact(&mut salt).await { match stream.read_exact(&mut salt).await {
Ok(_) => (), Ok(_) => (),
Err(_) => return Err(Error::SocketError), Err(_) => return Err(Error::SocketError(format!("Error reading salt on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
}; };
md5_password(&mut stream, &user.username, &user.password, &salt[..]) md5_password(&mut stream, &user.username, &user.password, &salt[..])
@@ -151,7 +154,7 @@ impl Server {
match stream.read_exact(&mut sasl_auth).await { match stream.read_exact(&mut sasl_auth).await {
Ok(_) => (), Ok(_) => (),
Err(_) => return Err(Error::SocketError), Err(_) => return Err(Error::SocketError(format!("Error reading sasl message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
}; };
let sasl_type = String::from_utf8_lossy(&sasl_auth[..sasl_len - 2]); let sasl_type = String::from_utf8_lossy(&sasl_auth[..sasl_len - 2]);
@@ -193,7 +196,7 @@ impl Server {
match stream.read_exact(&mut sasl_data).await { match stream.read_exact(&mut sasl_data).await {
Ok(_) => (), Ok(_) => (),
Err(_) => return Err(Error::SocketError), Err(_) => return Err(Error::SocketError(format!("Error reading sasl cont message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
}; };
let msg = BytesMut::from(&sasl_data[..]); let msg = BytesMut::from(&sasl_data[..]);
@@ -214,7 +217,7 @@ impl Server {
let mut sasl_final = vec![0u8; len as usize - 8]; let mut sasl_final = vec![0u8; len as usize - 8];
match stream.read_exact(&mut sasl_final).await { match stream.read_exact(&mut sasl_final).await {
Ok(_) => (), Ok(_) => (),
Err(_) => return Err(Error::SocketError), Err(_) => return Err(Error::SocketError(format!("Error reading sasl final message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
}; };
match scram.finish(&BytesMut::from(&sasl_final[..])) { match scram.finish(&BytesMut::from(&sasl_final[..])) {
@@ -240,7 +243,7 @@ impl Server {
'E' => { 'E' => {
let error_code = match stream.read_u8().await { let error_code = match stream.read_u8().await {
Ok(error_code) => error_code, Ok(error_code) => error_code,
Err(_) => return Err(Error::SocketError), Err(_) => return Err(Error::SocketError(format!("Error reading error code message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
}; };
trace!("Error: {}", error_code); trace!("Error: {}", error_code);
@@ -256,7 +259,7 @@ impl Server {
match stream.read_exact(&mut error).await { match stream.read_exact(&mut error).await {
Ok(_) => (), Ok(_) => (),
Err(_) => return Err(Error::SocketError), Err(_) => return Err(Error::SocketError(format!("Error reading error message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
}; };
// TODO: the error message contains multiple fields; we can decode them and // TODO: the error message contains multiple fields; we can decode them and
@@ -275,7 +278,7 @@ impl Server {
match stream.read_exact(&mut param).await { match stream.read_exact(&mut param).await {
Ok(_) => (), Ok(_) => (),
Err(_) => return Err(Error::SocketError), Err(_) => return Err(Error::SocketError(format!("Error reading parameter status message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
}; };
// Save the parameter so we can pass it to the client later. // Save the parameter so we can pass it to the client later.
@@ -292,12 +295,12 @@ impl Server {
// See: <https://www.postgresql.org/docs/12/protocol-message-formats.html>. // See: <https://www.postgresql.org/docs/12/protocol-message-formats.html>.
process_id = match stream.read_i32().await { process_id = match stream.read_i32().await {
Ok(id) => id, Ok(id) => id,
Err(_) => return Err(Error::SocketError), Err(_) => return Err(Error::SocketError(format!("Error reading process id message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
}; };
secret_key = match stream.read_i32().await { secret_key = match stream.read_i32().await {
Ok(id) => id, Ok(id) => id,
Err(_) => return Err(Error::SocketError), Err(_) => return Err(Error::SocketError(format!("Error reading secret key message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
}; };
} }
@@ -307,7 +310,7 @@ impl Server {
match stream.read_exact(&mut idle).await { match stream.read_exact(&mut idle).await {
Ok(_) => (), Ok(_) => (),
Err(_) => return Err(Error::SocketError), Err(_) => return Err(Error::SocketError(format!("Error reading transaction status message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
}; };
let (read, write) = stream.into_split(); let (read, write) = stream.into_split();
@@ -341,7 +344,10 @@ impl Server {
// Means we implemented the protocol wrong or we're not talking to a Postgres server. // Means we implemented the protocol wrong or we're not talking to a Postgres server.
_ => { _ => {
error!("Unknown code: {}", code); error!("Unknown code: {}", code);
return Err(Error::ProtocolSyncError); return Err(Error::ProtocolSyncError(format!(
"Unknown server code: {}",
code
)));
} }
}; };
} }
@@ -359,7 +365,7 @@ impl Server {
Ok(stream) => stream, Ok(stream) => stream,
Err(err) => { Err(err) => {
error!("Could not connect to server: {}", err); error!("Could not connect to server: {}", err);
return Err(Error::SocketError); return Err(Error::SocketError(format!("Error reading cancel message")));
} }
}; };
@@ -438,7 +444,10 @@ impl Server {
// Something totally unexpected, this is not a Postgres server we know. // Something totally unexpected, this is not a Postgres server we know.
_ => { _ => {
self.bad = true; self.bad = true;
return Err(Error::ProtocolSyncError); return Err(Error::ProtocolSyncError(format!(
"Unknown transaction state: {}",
transaction_state
)));
} }
}; };