Server TLS (#417)

* Server TLS

* Finish up TLS

* thats it

* diff

* remove dead code

* maybe?

* dirty shutdown

* skip flakey test

* remove unused error

* fetch config once
This commit is contained in:
Lev Kokotov
2023-04-30 09:41:46 -07:00
committed by GitHub
parent 4a87b4807d
commit 0d504032b2
10 changed files with 311 additions and 33 deletions

View File

@@ -116,7 +116,10 @@ where
/// Send the startup packet the server. We're pretending we're a Pg client.
/// This tells the server which user we are and what database we want.
pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Result<(), Error> {
pub async fn startup<S>(stream: &mut S, user: &str, database: &str) -> Result<(), Error>
where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
let mut bytes = BytesMut::with_capacity(25);
bytes.put_i32(196608); // Protocol number
@@ -150,6 +153,21 @@ pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Resu
}
}
pub async fn ssl_request(stream: &mut TcpStream) -> Result<(), Error> {
let mut bytes = BytesMut::with_capacity(12);
bytes.put_i32(8);
bytes.put_i32(80877103);
match stream.write_all(&bytes).await {
Ok(_) => Ok(()),
Err(err) => Err(Error::SocketError(format!(
"Error writing SSLRequest to server socket - Error: {:?}",
err
))),
}
}
/// Parse the params the server sends as a key/value format.
pub fn parse_params(mut bytes: BytesMut) -> Result<HashMap<String, String>, Error> {
let mut result = HashMap::new();
@@ -505,6 +523,29 @@ where
}
}
pub async fn write_all_flush<S>(stream: &mut S, buf: &[u8]) -> Result<(), Error>
where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
match stream.write_all(buf).await {
Ok(_) => match stream.flush().await {
Ok(_) => Ok(()),
Err(err) => {
return Err(Error::SocketError(format!(
"Error flushing socket - Error: {:?}",
err
)))
}
},
Err(err) => {
return Err(Error::SocketError(format!(
"Error writing to socket - Error: {:?}",
err
)))
}
}
}
/// Read a complete message from the socket.
pub async fn read_message<S>(stream: &mut S) -> Result<BytesMut, Error>
where