Adds details to errors and fixes error propagation bug (#239)

This commit is contained in:
zainkabani
2022-11-17 09:24:39 -08:00
committed by GitHub
parent fcd2cae4e1
commit c62b86f4e6
6 changed files with 94 additions and 48 deletions

View File

@@ -37,7 +37,10 @@ where
let code = query.get_u8() as char;
if code != 'Q' {
return Err(Error::ProtocolSyncError);
return Err(Error::ProtocolSyncError(format!(
"Invalid code, expected 'Q' but got '{}'",
code
)));
}
let len = query.get_i32() as usize;

View File

@@ -189,7 +189,12 @@ pub async fn client_entrypoint(
}
// Client probably disconnected rejecting our plain text connection.
_ => Err(Error::ProtocolSyncError),
Ok((ClientConnectionType::Tls, _))
| Ok((ClientConnectionType::CancelQuery, _)) => Err(Error::ProtocolSyncError(
format!("Bad postgres client (plain)"),
)),
Err(err) => Err(err),
}
}
}
@@ -297,7 +302,10 @@ where
// Something else, probably something is wrong and it's not our fault,
// e.g. badly implemented Postgres client.
_ => Err(Error::ProtocolSyncError),
_ => Err(Error::ProtocolSyncError(format!(
"Unexpected startup code: {}",
code
))),
}
}
@@ -343,7 +351,11 @@ pub async fn startup_tls(
}
// Bad Postgres client.
_ => Err(Error::ProtocolSyncError),
Ok((ClientConnectionType::Tls, _)) | Ok((ClientConnectionType::CancelQuery, _)) => Err(
Error::ProtocolSyncError(format!("Bad postgres client (tls)")),
),
Err(err) => Err(err),
}
}
@@ -373,7 +385,11 @@ where
// This parameter is mandatory by the protocol.
let username = match parameters.get("user") {
Some(user) => user,
None => return Err(Error::ClientError),
None => {
return Err(Error::ClientError(
"Missing user parameter on client startup".to_string(),
))
}
};
let pool_name = match parameters.get("database") {
@@ -416,25 +432,27 @@ where
let code = match read.read_u8().await {
Ok(p) => p,
Err(_) => return Err(Error::SocketError),
Err(_) => return Err(Error::SocketError(format!("Error reading password code from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name))),
};
// PasswordMessage
if code as char != 'p' {
debug!("Expected p, got {}", code as char);
return Err(Error::ProtocolSyncError);
return Err(Error::ProtocolSyncError(format!(
"Expected p, got {}",
code as char
)));
}
let len = match read.read_i32().await {
Ok(len) => len,
Err(_) => return Err(Error::SocketError),
Err(_) => return Err(Error::SocketError(format!("Error reading password message length from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name))),
};
let mut password_response = vec![0u8; (len - 4) as usize];
match read.read_exact(&mut password_response).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError),
Err(_) => return Err(Error::SocketError(format!("Error reading password message from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name))),
};
// Authenticate admin user.
@@ -451,7 +469,7 @@ where
warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name);
wrong_password(&mut write, username).await?;
return Err(Error::ClientError);
return Err(Error::ClientError(format!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name)));
}
(false, generate_server_info_for_admin())
@@ -470,8 +488,7 @@ where
)
.await?;
warn!("Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name);
return Err(Error::ClientError);
return Err(Error::ClientError(format!("Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name)));
}
};
@@ -482,7 +499,7 @@ where
warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name);
wrong_password(&mut write, username).await?;
return Err(Error::ClientError);
return Err(Error::ClientError(format!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name)));
}
let transaction_mode = pool.settings.pool_mode == PoolMode::Transaction;
@@ -669,8 +686,7 @@ where
)
.await?;
warn!("Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", self.pool_name, self.username, self.application_name);
return Err(Error::ClientError);
return Err(Error::ClientError(format!("Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", self.pool_name, self.username, self.application_name)));
}
};
query_router.update_pool_settings(pool.settings.clone());

View File

@@ -3,13 +3,13 @@
/// Various errors.
#[derive(Debug, PartialEq)]
pub enum Error {
SocketError,
SocketError(String),
ClientBadStartup,
ProtocolSyncError,
ProtocolSyncError(String),
ServerError,
BadConfig,
AllServersDown,
ClientError,
ClientError(String),
TlsError,
StatementTimeout,
ShuttingDown,

View File

@@ -136,7 +136,11 @@ pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Resu
match stream.write_all(&startup).await {
Ok(_) => Ok(()),
Err(_) => Err(Error::SocketError),
Err(_) => {
return Err(Error::SocketError(format!(
"Error writing startup to server socket"
)))
}
}
}
@@ -450,7 +454,7 @@ where
{
match stream.write_all(&buf).await {
Ok(_) => Ok(()),
Err(_) => Err(Error::SocketError),
Err(_) => return Err(Error::SocketError(format!("Error writing to socket"))),
}
}
@@ -461,7 +465,7 @@ where
{
match stream.write_all(&buf).await {
Ok(_) => Ok(()),
Err(_) => Err(Error::SocketError),
Err(_) => return Err(Error::SocketError(format!("Error writing to socket"))),
}
}
@@ -472,19 +476,33 @@ where
{
let code = match stream.read_u8().await {
Ok(code) => code,
Err(_) => return Err(Error::SocketError),
Err(_) => {
return Err(Error::SocketError(format!(
"Error reading message code from socket"
)))
}
};
let len = match stream.read_i32().await {
Ok(len) => len,
Err(_) => return Err(Error::SocketError),
Err(_) => {
return Err(Error::SocketError(format!(
"Error reading message len from socket, code: {:?}",
code
)))
}
};
let mut buf = vec![0u8; len as usize - 4];
match stream.read_exact(&mut buf).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError),
Err(_) => {
return Err(Error::SocketError(format!(
"Error reading message from socket, code: {:?}",
code
)))
}
};
let mut bytes = BytesMut::with_capacity(len as usize + 1);

View File

@@ -78,12 +78,12 @@ impl ScramSha256 {
let server_message = Message::parse(message)?;
if !server_message.nonce.starts_with(&self.nonce) {
return Err(Error::ProtocolSyncError);
return Err(Error::ProtocolSyncError(format!("SCRAM")));
}
let salt = match base64::decode(&server_message.salt) {
Ok(salt) => salt,
Err(_) => return Err(Error::ProtocolSyncError),
Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))),
};
let salted_password = Self::hi(
@@ -163,7 +163,7 @@ impl ScramSha256 {
let verifier = match base64::decode(&final_message.value) {
Ok(verifier) => verifier,
Err(_) => return Err(Error::ProtocolSyncError),
Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))),
};
let mut hmac = match Hmac::<Sha256>::new_from_slice(&self.salted_password) {
@@ -225,14 +225,14 @@ impl Message {
.collect::<Vec<String>>();
if parts.len() != 3 {
return Err(Error::ProtocolSyncError);
return Err(Error::ProtocolSyncError(format!("SCRAM")));
}
let nonce = str::replace(&parts[0], "r=", "");
let salt = str::replace(&parts[1], "s=", "");
let iterations = match str::replace(&parts[2], "i=", "").parse::<u32>() {
Ok(iterations) => iterations,
Err(_) => return Err(Error::ProtocolSyncError),
Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))),
};
Ok(Message {
@@ -252,7 +252,7 @@ impl FinalMessage {
/// Parse the server final validation message.
pub fn parse(message: &BytesMut) -> Result<FinalMessage, Error> {
if !message.starts_with(b"v=") || message.len() < 4 {
return Err(Error::ProtocolSyncError);
return Err(Error::ProtocolSyncError(format!("SCRAM")));
}
Ok(FinalMessage {

View File

@@ -86,7 +86,10 @@ impl Server {
Ok(stream) => stream,
Err(err) => {
error!("Could not connect to server: {}", err);
return Err(Error::SocketError);
return Err(Error::SocketError(format!(
"Could not connect to server: {}",
err
)));
}
};
@@ -106,12 +109,12 @@ impl Server {
loop {
let code = match stream.read_u8().await {
Ok(code) => code as char,
Err(_) => return Err(Error::SocketError),
Err(_) => return Err(Error::SocketError(format!("Error reading message code on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
};
let len = match stream.read_i32().await {
Ok(len) => len,
Err(_) => return Err(Error::SocketError),
Err(_) => return Err(Error::SocketError(format!("Error reading message len on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
};
trace!("Message: {}", code);
@@ -122,7 +125,7 @@ impl Server {
// Determine which kind of authentication is required, if any.
let auth_code = match stream.read_i32().await {
Ok(auth_code) => auth_code,
Err(_) => return Err(Error::SocketError),
Err(_) => return Err(Error::SocketError(format!("Error reading auth code on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
};
trace!("Auth: {}", auth_code);
@@ -135,7 +138,7 @@ impl Server {
match stream.read_exact(&mut salt).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError),
Err(_) => return Err(Error::SocketError(format!("Error reading salt on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
};
md5_password(&mut stream, &user.username, &user.password, &salt[..])
@@ -151,7 +154,7 @@ impl Server {
match stream.read_exact(&mut sasl_auth).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError),
Err(_) => return Err(Error::SocketError(format!("Error reading sasl message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
};
let sasl_type = String::from_utf8_lossy(&sasl_auth[..sasl_len - 2]);
@@ -193,7 +196,7 @@ impl Server {
match stream.read_exact(&mut sasl_data).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError),
Err(_) => return Err(Error::SocketError(format!("Error reading sasl cont message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
};
let msg = BytesMut::from(&sasl_data[..]);
@@ -214,7 +217,7 @@ impl Server {
let mut sasl_final = vec![0u8; len as usize - 8];
match stream.read_exact(&mut sasl_final).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError),
Err(_) => return Err(Error::SocketError(format!("Error reading sasl final message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
};
match scram.finish(&BytesMut::from(&sasl_final[..])) {
@@ -240,7 +243,7 @@ impl Server {
'E' => {
let error_code = match stream.read_u8().await {
Ok(error_code) => error_code,
Err(_) => return Err(Error::SocketError),
Err(_) => return Err(Error::SocketError(format!("Error reading error code message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
};
trace!("Error: {}", error_code);
@@ -256,7 +259,7 @@ impl Server {
match stream.read_exact(&mut error).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError),
Err(_) => return Err(Error::SocketError(format!("Error reading error message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
};
// TODO: the error message contains multiple fields; we can decode them and
@@ -275,7 +278,7 @@ impl Server {
match stream.read_exact(&mut param).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError),
Err(_) => return Err(Error::SocketError(format!("Error reading parameter status message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
};
// Save the parameter so we can pass it to the client later.
@@ -292,12 +295,12 @@ impl Server {
// See: <https://www.postgresql.org/docs/12/protocol-message-formats.html>.
process_id = match stream.read_i32().await {
Ok(id) => id,
Err(_) => return Err(Error::SocketError),
Err(_) => return Err(Error::SocketError(format!("Error reading process id message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
};
secret_key = match stream.read_i32().await {
Ok(id) => id,
Err(_) => return Err(Error::SocketError),
Err(_) => return Err(Error::SocketError(format!("Error reading secret key message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
};
}
@@ -307,7 +310,7 @@ impl Server {
match stream.read_exact(&mut idle).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError),
Err(_) => return Err(Error::SocketError(format!("Error reading transaction status message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
};
let (read, write) = stream.into_split();
@@ -341,7 +344,10 @@ impl Server {
// Means we implemented the protocol wrong or we're not talking to a Postgres server.
_ => {
error!("Unknown code: {}", code);
return Err(Error::ProtocolSyncError);
return Err(Error::ProtocolSyncError(format!(
"Unknown server code: {}",
code
)));
}
};
}
@@ -359,7 +365,7 @@ impl Server {
Ok(stream) => stream,
Err(err) => {
error!("Could not connect to server: {}", err);
return Err(Error::SocketError);
return Err(Error::SocketError(format!("Error reading cancel message")));
}
};
@@ -438,7 +444,10 @@ impl Server {
// Something totally unexpected, this is not a Postgres server we know.
_ => {
self.bad = true;
return Err(Error::ProtocolSyncError);
return Err(Error::ProtocolSyncError(format!(
"Unknown transaction state: {}",
transaction_state
)));
}
};