mirror of
https://github.com/postgresml/pgcat.git
synced 2026-03-23 01:16:30 +00:00
Adds details to errors and fixes error propagation bug (#239)
This commit is contained in:
@@ -37,7 +37,10 @@ where
|
||||
let code = query.get_u8() as char;
|
||||
|
||||
if code != 'Q' {
|
||||
return Err(Error::ProtocolSyncError);
|
||||
return Err(Error::ProtocolSyncError(format!(
|
||||
"Invalid code, expected 'Q' but got '{}'",
|
||||
code
|
||||
)));
|
||||
}
|
||||
|
||||
let len = query.get_i32() as usize;
|
||||
|
||||
@@ -189,7 +189,12 @@ pub async fn client_entrypoint(
|
||||
}
|
||||
|
||||
// Client probably disconnected rejecting our plain text connection.
|
||||
_ => Err(Error::ProtocolSyncError),
|
||||
Ok((ClientConnectionType::Tls, _))
|
||||
| Ok((ClientConnectionType::CancelQuery, _)) => Err(Error::ProtocolSyncError(
|
||||
format!("Bad postgres client (plain)"),
|
||||
)),
|
||||
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -297,7 +302,10 @@ where
|
||||
|
||||
// Something else, probably something is wrong and it's not our fault,
|
||||
// e.g. badly implemented Postgres client.
|
||||
_ => Err(Error::ProtocolSyncError),
|
||||
_ => Err(Error::ProtocolSyncError(format!(
|
||||
"Unexpected startup code: {}",
|
||||
code
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -343,7 +351,11 @@ pub async fn startup_tls(
|
||||
}
|
||||
|
||||
// Bad Postgres client.
|
||||
_ => Err(Error::ProtocolSyncError),
|
||||
Ok((ClientConnectionType::Tls, _)) | Ok((ClientConnectionType::CancelQuery, _)) => Err(
|
||||
Error::ProtocolSyncError(format!("Bad postgres client (tls)")),
|
||||
),
|
||||
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -373,7 +385,11 @@ where
|
||||
// This parameter is mandatory by the protocol.
|
||||
let username = match parameters.get("user") {
|
||||
Some(user) => user,
|
||||
None => return Err(Error::ClientError),
|
||||
None => {
|
||||
return Err(Error::ClientError(
|
||||
"Missing user parameter on client startup".to_string(),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let pool_name = match parameters.get("database") {
|
||||
@@ -416,25 +432,27 @@ where
|
||||
|
||||
let code = match read.read_u8().await {
|
||||
Ok(p) => p,
|
||||
Err(_) => return Err(Error::SocketError),
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading password code from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name))),
|
||||
};
|
||||
|
||||
// PasswordMessage
|
||||
if code as char != 'p' {
|
||||
debug!("Expected p, got {}", code as char);
|
||||
return Err(Error::ProtocolSyncError);
|
||||
return Err(Error::ProtocolSyncError(format!(
|
||||
"Expected p, got {}",
|
||||
code as char
|
||||
)));
|
||||
}
|
||||
|
||||
let len = match read.read_i32().await {
|
||||
Ok(len) => len,
|
||||
Err(_) => return Err(Error::SocketError),
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading password message length from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name))),
|
||||
};
|
||||
|
||||
let mut password_response = vec![0u8; (len - 4) as usize];
|
||||
|
||||
match read.read_exact(&mut password_response).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => return Err(Error::SocketError),
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading password message from client {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name))),
|
||||
};
|
||||
|
||||
// Authenticate admin user.
|
||||
@@ -451,7 +469,7 @@ where
|
||||
warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name);
|
||||
wrong_password(&mut write, username).await?;
|
||||
|
||||
return Err(Error::ClientError);
|
||||
return Err(Error::ClientError(format!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name)));
|
||||
}
|
||||
|
||||
(false, generate_server_info_for_admin())
|
||||
@@ -470,8 +488,7 @@ where
|
||||
)
|
||||
.await?;
|
||||
|
||||
warn!("Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name);
|
||||
return Err(Error::ClientError);
|
||||
return Err(Error::ClientError(format!("Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name)));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -482,7 +499,7 @@ where
|
||||
warn!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name);
|
||||
wrong_password(&mut write, username).await?;
|
||||
|
||||
return Err(Error::ClientError);
|
||||
return Err(Error::ClientError(format!("Invalid password {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", pool_name, username, application_name)));
|
||||
}
|
||||
|
||||
let transaction_mode = pool.settings.pool_mode == PoolMode::Transaction;
|
||||
@@ -669,8 +686,7 @@ where
|
||||
)
|
||||
.await?;
|
||||
|
||||
warn!("Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", self.pool_name, self.username, self.application_name);
|
||||
return Err(Error::ClientError);
|
||||
return Err(Error::ClientError(format!("Invalid pool name {{ username: {:?}, pool_name: {:?}, application_name: {:?} }}", self.pool_name, self.username, self.application_name)));
|
||||
}
|
||||
};
|
||||
query_router.update_pool_settings(pool.settings.clone());
|
||||
|
||||
@@ -3,13 +3,13 @@
|
||||
/// Various errors.
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub enum Error {
|
||||
SocketError,
|
||||
SocketError(String),
|
||||
ClientBadStartup,
|
||||
ProtocolSyncError,
|
||||
ProtocolSyncError(String),
|
||||
ServerError,
|
||||
BadConfig,
|
||||
AllServersDown,
|
||||
ClientError,
|
||||
ClientError(String),
|
||||
TlsError,
|
||||
StatementTimeout,
|
||||
ShuttingDown,
|
||||
|
||||
@@ -136,7 +136,11 @@ pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Resu
|
||||
|
||||
match stream.write_all(&startup).await {
|
||||
Ok(_) => Ok(()),
|
||||
Err(_) => Err(Error::SocketError),
|
||||
Err(_) => {
|
||||
return Err(Error::SocketError(format!(
|
||||
"Error writing startup to server socket"
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -450,7 +454,7 @@ where
|
||||
{
|
||||
match stream.write_all(&buf).await {
|
||||
Ok(_) => Ok(()),
|
||||
Err(_) => Err(Error::SocketError),
|
||||
Err(_) => return Err(Error::SocketError(format!("Error writing to socket"))),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -461,7 +465,7 @@ where
|
||||
{
|
||||
match stream.write_all(&buf).await {
|
||||
Ok(_) => Ok(()),
|
||||
Err(_) => Err(Error::SocketError),
|
||||
Err(_) => return Err(Error::SocketError(format!("Error writing to socket"))),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -472,19 +476,33 @@ where
|
||||
{
|
||||
let code = match stream.read_u8().await {
|
||||
Ok(code) => code,
|
||||
Err(_) => return Err(Error::SocketError),
|
||||
Err(_) => {
|
||||
return Err(Error::SocketError(format!(
|
||||
"Error reading message code from socket"
|
||||
)))
|
||||
}
|
||||
};
|
||||
|
||||
let len = match stream.read_i32().await {
|
||||
Ok(len) => len,
|
||||
Err(_) => return Err(Error::SocketError),
|
||||
Err(_) => {
|
||||
return Err(Error::SocketError(format!(
|
||||
"Error reading message len from socket, code: {:?}",
|
||||
code
|
||||
)))
|
||||
}
|
||||
};
|
||||
|
||||
let mut buf = vec![0u8; len as usize - 4];
|
||||
|
||||
match stream.read_exact(&mut buf).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => return Err(Error::SocketError),
|
||||
Err(_) => {
|
||||
return Err(Error::SocketError(format!(
|
||||
"Error reading message from socket, code: {:?}",
|
||||
code
|
||||
)))
|
||||
}
|
||||
};
|
||||
|
||||
let mut bytes = BytesMut::with_capacity(len as usize + 1);
|
||||
|
||||
12
src/scram.rs
12
src/scram.rs
@@ -78,12 +78,12 @@ impl ScramSha256 {
|
||||
let server_message = Message::parse(message)?;
|
||||
|
||||
if !server_message.nonce.starts_with(&self.nonce) {
|
||||
return Err(Error::ProtocolSyncError);
|
||||
return Err(Error::ProtocolSyncError(format!("SCRAM")));
|
||||
}
|
||||
|
||||
let salt = match base64::decode(&server_message.salt) {
|
||||
Ok(salt) => salt,
|
||||
Err(_) => return Err(Error::ProtocolSyncError),
|
||||
Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))),
|
||||
};
|
||||
|
||||
let salted_password = Self::hi(
|
||||
@@ -163,7 +163,7 @@ impl ScramSha256 {
|
||||
|
||||
let verifier = match base64::decode(&final_message.value) {
|
||||
Ok(verifier) => verifier,
|
||||
Err(_) => return Err(Error::ProtocolSyncError),
|
||||
Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))),
|
||||
};
|
||||
|
||||
let mut hmac = match Hmac::<Sha256>::new_from_slice(&self.salted_password) {
|
||||
@@ -225,14 +225,14 @@ impl Message {
|
||||
.collect::<Vec<String>>();
|
||||
|
||||
if parts.len() != 3 {
|
||||
return Err(Error::ProtocolSyncError);
|
||||
return Err(Error::ProtocolSyncError(format!("SCRAM")));
|
||||
}
|
||||
|
||||
let nonce = str::replace(&parts[0], "r=", "");
|
||||
let salt = str::replace(&parts[1], "s=", "");
|
||||
let iterations = match str::replace(&parts[2], "i=", "").parse::<u32>() {
|
||||
Ok(iterations) => iterations,
|
||||
Err(_) => return Err(Error::ProtocolSyncError),
|
||||
Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))),
|
||||
};
|
||||
|
||||
Ok(Message {
|
||||
@@ -252,7 +252,7 @@ impl FinalMessage {
|
||||
/// Parse the server final validation message.
|
||||
pub fn parse(message: &BytesMut) -> Result<FinalMessage, Error> {
|
||||
if !message.starts_with(b"v=") || message.len() < 4 {
|
||||
return Err(Error::ProtocolSyncError);
|
||||
return Err(Error::ProtocolSyncError(format!("SCRAM")));
|
||||
}
|
||||
|
||||
Ok(FinalMessage {
|
||||
|
||||
@@ -86,7 +86,10 @@ impl Server {
|
||||
Ok(stream) => stream,
|
||||
Err(err) => {
|
||||
error!("Could not connect to server: {}", err);
|
||||
return Err(Error::SocketError);
|
||||
return Err(Error::SocketError(format!(
|
||||
"Could not connect to server: {}",
|
||||
err
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -106,12 +109,12 @@ impl Server {
|
||||
loop {
|
||||
let code = match stream.read_u8().await {
|
||||
Ok(code) => code as char,
|
||||
Err(_) => return Err(Error::SocketError),
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading message code on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
|
||||
};
|
||||
|
||||
let len = match stream.read_i32().await {
|
||||
Ok(len) => len,
|
||||
Err(_) => return Err(Error::SocketError),
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading message len on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
|
||||
};
|
||||
|
||||
trace!("Message: {}", code);
|
||||
@@ -122,7 +125,7 @@ impl Server {
|
||||
// Determine which kind of authentication is required, if any.
|
||||
let auth_code = match stream.read_i32().await {
|
||||
Ok(auth_code) => auth_code,
|
||||
Err(_) => return Err(Error::SocketError),
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading auth code on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
|
||||
};
|
||||
|
||||
trace!("Auth: {}", auth_code);
|
||||
@@ -135,7 +138,7 @@ impl Server {
|
||||
|
||||
match stream.read_exact(&mut salt).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => return Err(Error::SocketError),
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading salt on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
|
||||
};
|
||||
|
||||
md5_password(&mut stream, &user.username, &user.password, &salt[..])
|
||||
@@ -151,7 +154,7 @@ impl Server {
|
||||
|
||||
match stream.read_exact(&mut sasl_auth).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => return Err(Error::SocketError),
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading sasl message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
|
||||
};
|
||||
|
||||
let sasl_type = String::from_utf8_lossy(&sasl_auth[..sasl_len - 2]);
|
||||
@@ -193,7 +196,7 @@ impl Server {
|
||||
|
||||
match stream.read_exact(&mut sasl_data).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => return Err(Error::SocketError),
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading sasl cont message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
|
||||
};
|
||||
|
||||
let msg = BytesMut::from(&sasl_data[..]);
|
||||
@@ -214,7 +217,7 @@ impl Server {
|
||||
let mut sasl_final = vec![0u8; len as usize - 8];
|
||||
match stream.read_exact(&mut sasl_final).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => return Err(Error::SocketError),
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading sasl final message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
|
||||
};
|
||||
|
||||
match scram.finish(&BytesMut::from(&sasl_final[..])) {
|
||||
@@ -240,7 +243,7 @@ impl Server {
|
||||
'E' => {
|
||||
let error_code = match stream.read_u8().await {
|
||||
Ok(error_code) => error_code,
|
||||
Err(_) => return Err(Error::SocketError),
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading error code message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
|
||||
};
|
||||
|
||||
trace!("Error: {}", error_code);
|
||||
@@ -256,7 +259,7 @@ impl Server {
|
||||
|
||||
match stream.read_exact(&mut error).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => return Err(Error::SocketError),
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading error message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
|
||||
};
|
||||
|
||||
// TODO: the error message contains multiple fields; we can decode them and
|
||||
@@ -275,7 +278,7 @@ impl Server {
|
||||
|
||||
match stream.read_exact(&mut param).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => return Err(Error::SocketError),
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading parameter status message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
|
||||
};
|
||||
|
||||
// Save the parameter so we can pass it to the client later.
|
||||
@@ -292,12 +295,12 @@ impl Server {
|
||||
// See: <https://www.postgresql.org/docs/12/protocol-message-formats.html>.
|
||||
process_id = match stream.read_i32().await {
|
||||
Ok(id) => id,
|
||||
Err(_) => return Err(Error::SocketError),
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading process id message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
|
||||
};
|
||||
|
||||
secret_key = match stream.read_i32().await {
|
||||
Ok(id) => id,
|
||||
Err(_) => return Err(Error::SocketError),
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading secret key message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
|
||||
};
|
||||
}
|
||||
|
||||
@@ -307,7 +310,7 @@ impl Server {
|
||||
|
||||
match stream.read_exact(&mut idle).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => return Err(Error::SocketError),
|
||||
Err(_) => return Err(Error::SocketError(format!("Error reading transaction status message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
|
||||
};
|
||||
|
||||
let (read, write) = stream.into_split();
|
||||
@@ -341,7 +344,10 @@ impl Server {
|
||||
// Means we implemented the protocol wrong or we're not talking to a Postgres server.
|
||||
_ => {
|
||||
error!("Unknown code: {}", code);
|
||||
return Err(Error::ProtocolSyncError);
|
||||
return Err(Error::ProtocolSyncError(format!(
|
||||
"Unknown server code: {}",
|
||||
code
|
||||
)));
|
||||
}
|
||||
};
|
||||
}
|
||||
@@ -359,7 +365,7 @@ impl Server {
|
||||
Ok(stream) => stream,
|
||||
Err(err) => {
|
||||
error!("Could not connect to server: {}", err);
|
||||
return Err(Error::SocketError);
|
||||
return Err(Error::SocketError(format!("Error reading cancel message")));
|
||||
}
|
||||
};
|
||||
|
||||
@@ -438,7 +444,10 @@ impl Server {
|
||||
// Something totally unexpected, this is not a Postgres server we know.
|
||||
_ => {
|
||||
self.bad = true;
|
||||
return Err(Error::ProtocolSyncError);
|
||||
return Err(Error::ProtocolSyncError(format!(
|
||||
"Unknown transaction state: {}",
|
||||
transaction_state
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user