A couple things (#397)

* Format cleanup

* fmt

* finally
This commit is contained in:
Lev Kokotov
2023-04-10 14:51:01 -07:00
committed by GitHub
parent a62f6b0eea
commit 692353c839
13 changed files with 366 additions and 100 deletions

View File

@@ -17,7 +17,7 @@ use tokio::net::{
use crate::config::{Address, User};
use crate::constants::*;
use crate::errors::Error;
use crate::errors::{Error, ServerIdentifier};
use crate::messages::*;
use crate::mirrors::MirroringManager;
use crate::pool::ClientServerMap;
@@ -108,6 +108,7 @@ impl Server {
let mut server_info = BytesMut::new();
let mut process_id: i32 = 0;
let mut secret_key: i32 = 0;
let server_identifier = ServerIdentifier::new(&user.username, &database);
// We'll be handling multiple packets, but they will all be structured the same.
// We'll loop here until this exchange is complete.
@@ -119,12 +120,22 @@ impl Server {
loop {
let code = match stream.read_u8().await {
Ok(code) => code as char,
Err(_) => return Err(Error::SocketError(format!("Error reading message code on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
Err(_) => {
return Err(Error::ServerStartupError(
"message code".into(),
server_identifier,
))
}
};
let len = match stream.read_i32().await {
Ok(len) => len,
Err(_) => return Err(Error::SocketError(format!("Error reading message len on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
Err(_) => {
return Err(Error::ServerStartupError(
"message len".into(),
server_identifier,
))
}
};
trace!("Message: {}", code);
@@ -135,7 +146,12 @@ impl Server {
// Determine which kind of authentication is required, if any.
let auth_code = match stream.read_i32().await {
Ok(auth_code) => auth_code,
Err(_) => return Err(Error::SocketError(format!("Error reading auth code on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
Err(_) => {
return Err(Error::ServerStartupError(
"auth code".into(),
server_identifier,
))
}
};
trace!("Auth: {}", auth_code);
@@ -148,7 +164,12 @@ impl Server {
match stream.read_exact(&mut salt).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError(format!("Error reading salt on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
Err(_) => {
return Err(Error::ServerStartupError(
"salt".into(),
server_identifier,
))
}
};
match &user.password {
@@ -171,8 +192,12 @@ impl Server {
&salt[..],
)
.await?,
None =>
return Err(Error::AuthError(format!("Auth passthrough (auth_query) failed and no user password is set in cleartext for {{ username: {:?}, database: {:?} }}", user.username, database)))
None => return Err(
Error::ServerAuthError(
"Auth passthrough (auth_query) failed and no user password is set in cleartext".into(),
server_identifier
)
),
}
}
}
@@ -182,16 +207,28 @@ impl Server {
SASL => {
if scram.is_none() {
return Err(Error::AuthError(format!("SASL auth required and not password specified, auth passthrough (auth_query) method is currently unsupported for SASL auth {{ username: {:?}, database: {:?} }}", user.username, database)));
return Err(Error::ServerAuthError(
"SASL auth required and no password specified. \
Auth passthrough (auth_query) method is currently \
unsupported for SASL auth"
.into(),
server_identifier,
));
}
debug!("Starting SASL authentication");
let sasl_len = (len - 8) as usize;
let mut sasl_auth = vec![0u8; sasl_len];
match stream.read_exact(&mut sasl_auth).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError(format!("Error reading sasl message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
Err(_) => {
return Err(Error::ServerStartupError(
"sasl message".into(),
server_identifier,
))
}
};
let sasl_type = String::from_utf8_lossy(&sasl_auth[..sasl_len - 2]);
@@ -233,7 +270,12 @@ impl Server {
match stream.read_exact(&mut sasl_data).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError(format!("Error reading sasl cont message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
Err(_) => {
return Err(Error::ServerStartupError(
"sasl cont message".into(),
server_identifier,
))
}
};
let msg = BytesMut::from(&sasl_data[..]);
@@ -254,7 +296,12 @@ impl Server {
let mut sasl_final = vec![0u8; len as usize - 8];
match stream.read_exact(&mut sasl_final).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError(format!("Error reading sasl final message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
Err(_) => {
return Err(Error::ServerStartupError(
"sasl final message".into(),
server_identifier,
))
}
};
match scram
@@ -284,7 +331,12 @@ impl Server {
'E' => {
let error_code = match stream.read_u8().await {
Ok(error_code) => error_code,
Err(_) => return Err(Error::SocketError(format!("Error reading error code message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
Err(_) => {
return Err(Error::ServerStartupError(
"error code message".into(),
server_identifier,
))
}
};
trace!("Error: {}", error_code);
@@ -300,7 +352,12 @@ impl Server {
match stream.read_exact(&mut error).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError(format!("Error reading error message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
Err(_) => {
return Err(Error::ServerStartupError(
"error message".into(),
server_identifier,
))
}
};
// TODO: the error message contains multiple fields; we can decode them and
@@ -319,7 +376,12 @@ impl Server {
match stream.read_exact(&mut param).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError(format!("Error reading parameter status message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
Err(_) => {
return Err(Error::ServerStartupError(
"parameter status message".into(),
server_identifier,
))
}
};
// Save the parameter so we can pass it to the client later.
@@ -336,12 +398,22 @@ impl Server {
// See: <https://www.postgresql.org/docs/12/protocol-message-formats.html>.
process_id = match stream.read_i32().await {
Ok(id) => id,
Err(_) => return Err(Error::SocketError(format!("Error reading process id message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
Err(_) => {
return Err(Error::ServerStartupError(
"process id message".into(),
server_identifier,
))
}
};
secret_key = match stream.read_i32().await {
Ok(id) => id,
Err(_) => return Err(Error::SocketError(format!("Error reading secret key message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
Err(_) => {
return Err(Error::ServerStartupError(
"secret key message".into(),
server_identifier,
))
}
};
}
@@ -351,7 +423,12 @@ impl Server {
match stream.read_exact(&mut idle).await {
Ok(_) => (),
Err(_) => return Err(Error::SocketError(format!("Error reading transaction status message on server startup {{ username: {:?}, database: {:?} }}", user.username, database))),
Err(_) => {
return Err(Error::ServerStartupError(
"transaction status message".into(),
server_identifier,
))
}
};
let (read, write) = stream.into_split();
@@ -413,7 +490,7 @@ impl Server {
Ok(stream) => stream,
Err(err) => {
error!("Could not connect to server: {}", err);
return Err(Error::SocketError(format!("Error reading cancel message")));
return Err(Error::SocketError("Error reading cancel message".into()));
}
};
configure_socket(&stream);