2022-02-15 22:45:45 -08:00
|
|
|
/// Helper functions to send one-off protocol messages
|
|
|
|
|
/// and handle TcpStream (TCP socket).
|
2022-02-14 05:11:53 -08:00
|
|
|
use bytes::{Buf, BufMut, BytesMut};
|
2023-06-18 23:02:34 -07:00
|
|
|
use log::{debug, error};
|
2022-02-03 15:17:04 -08:00
|
|
|
use md5::{Digest, Md5};
|
2023-02-08 11:35:38 -06:00
|
|
|
use socket2::{SockRef, TcpKeepalive};
|
2022-06-27 16:45:41 -07:00
|
|
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
|
|
|
|
use tokio::net::TcpStream;
|
2022-02-03 13:35:40 -08:00
|
|
|
|
2023-06-16 12:57:44 -07:00
|
|
|
use crate::client::PREPARED_STATEMENT_COUNTER;
|
2023-02-08 11:35:38 -06:00
|
|
|
use crate::config::get_config;
|
2022-03-01 08:47:19 -08:00
|
|
|
use crate::errors::Error;
|
2023-06-16 12:57:44 -07:00
|
|
|
|
2023-08-09 13:14:05 -03:00
|
|
|
use crate::constants::MESSAGE_TERMINATOR;
|
2023-10-25 18:11:57 -04:00
|
|
|
use std::collections::hash_map::DefaultHasher;
|
2022-02-14 05:11:53 -08:00
|
|
|
use std::collections::HashMap;
|
2023-06-16 12:57:44 -07:00
|
|
|
use std::ffi::CString;
|
2023-08-09 13:14:05 -03:00
|
|
|
use std::fmt::{Display, Formatter};
|
2023-10-25 18:11:57 -04:00
|
|
|
use std::hash::{Hash, Hasher};
|
2023-01-19 10:19:49 -05:00
|
|
|
use std::io::{BufRead, Cursor};
|
2022-07-31 21:52:23 -05:00
|
|
|
use std::mem;
|
2023-08-09 13:14:05 -03:00
|
|
|
use std::str::FromStr;
|
2023-06-16 12:57:44 -07:00
|
|
|
use std::sync::atomic::Ordering;
|
2023-10-25 18:11:57 -04:00
|
|
|
use std::sync::Arc;
|
2023-02-08 11:35:38 -06:00
|
|
|
use std::time::Duration;
|
2022-02-14 05:11:53 -08:00
|
|
|
|
2022-03-01 08:47:19 -08:00
|
|
|
/// Postgres data type mappings
/// used in RowDescription ('T') message.
///
/// A reference to a variant converts to its numeric type identifier through
/// `i32::from(&DataType)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataType {
    /// `text`
    Text,
    /// `int4`
    Int4,
    /// `numeric`
    Numeric,
    /// `bool`
    Bool,
    /// `oid`
    Oid,
    /// `anyarray`
    AnyArray,
    /// `any`
    Any,
}
|
|
|
|
|
|
|
|
|
|
impl From<&DataType> for i32 {
|
|
|
|
|
fn from(data_type: &DataType) -> i32 {
|
|
|
|
|
match data_type {
|
|
|
|
|
DataType::Text => 25,
|
|
|
|
|
DataType::Int4 => 23,
|
|
|
|
|
DataType::Numeric => 1700,
|
2023-05-03 09:13:05 -07:00
|
|
|
DataType::Bool => 16,
|
|
|
|
|
DataType::Oid => 26,
|
|
|
|
|
DataType::AnyArray => 2277,
|
|
|
|
|
DataType::Any => 2276,
|
2022-03-01 08:47:19 -08:00
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
2022-02-03 13:35:40 -08:00
|
|
|
|
2022-02-08 09:33:20 -08:00
|
|
|
/// Tell the client that authentication handshake completed successfully.
|
2022-06-27 09:46:33 -07:00
|
|
|
pub async fn auth_ok<S>(stream: &mut S) -> Result<(), Error>
|
2022-06-27 15:52:01 -07:00
|
|
|
where
|
|
|
|
|
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
|
|
|
|
{
|
2022-02-03 13:35:40 -08:00
|
|
|
let mut auth_ok = BytesMut::with_capacity(9);
|
|
|
|
|
|
|
|
|
|
auth_ok.put_u8(b'R');
|
|
|
|
|
auth_ok.put_i32(8);
|
|
|
|
|
auth_ok.put_i32(0);
|
|
|
|
|
|
2022-11-10 02:04:31 +08:00
|
|
|
write_all(stream, auth_ok).await
|
2022-02-03 13:35:40 -08:00
|
|
|
}
|
|
|
|
|
|
2024-09-05 09:40:03 -05:00
|
|
|
/// Tell the client to use clearr text auth
|
|
|
|
|
pub async fn clear_text_challenge<S>(stream: &mut S) -> Result<(), Error>
|
|
|
|
|
where
|
|
|
|
|
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
|
|
|
|
{
|
|
|
|
|
let mut auth_clear_text = BytesMut::with_capacity(9);
|
|
|
|
|
|
|
|
|
|
auth_clear_text.put_u8(b'R');
|
|
|
|
|
auth_clear_text.put_i32(8);
|
|
|
|
|
auth_clear_text.put_i32(3);
|
|
|
|
|
|
|
|
|
|
write_all(stream, auth_clear_text).await
|
|
|
|
|
}
|
|
|
|
|
|
2022-06-20 06:15:54 -07:00
|
|
|
/// Generate md5 password challenge.
|
2022-06-27 09:46:33 -07:00
|
|
|
pub async fn md5_challenge<S>(stream: &mut S) -> Result<[u8; 4], Error>
|
2022-06-27 15:52:01 -07:00
|
|
|
where
|
|
|
|
|
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
|
|
|
|
{
|
2022-06-20 06:15:54 -07:00
|
|
|
// let mut rng = rand::thread_rng();
|
|
|
|
|
let salt: [u8; 4] = [
|
|
|
|
|
rand::random(),
|
|
|
|
|
rand::random(),
|
|
|
|
|
rand::random(),
|
|
|
|
|
rand::random(),
|
|
|
|
|
];
|
|
|
|
|
|
|
|
|
|
let mut res = BytesMut::new();
|
|
|
|
|
res.put_u8(b'R');
|
|
|
|
|
res.put_i32(12);
|
|
|
|
|
res.put_i32(5); // MD5
|
|
|
|
|
res.put_slice(&salt[..]);
|
|
|
|
|
|
|
|
|
|
write_all(stream, res).await?;
|
|
|
|
|
Ok(salt)
|
|
|
|
|
}
|
|
|
|
|
|
2022-02-08 09:33:20 -08:00
|
|
|
/// Give the client the process_id and secret we generated
|
|
|
|
|
/// used in query cancellation.
|
2022-06-27 09:46:33 -07:00
|
|
|
pub async fn backend_key_data<S>(
|
|
|
|
|
stream: &mut S,
|
2022-02-04 09:28:52 -08:00
|
|
|
backend_id: i32,
|
|
|
|
|
secret_key: i32,
|
2022-06-27 09:46:33 -07:00
|
|
|
) -> Result<(), Error>
|
2022-06-27 15:52:01 -07:00
|
|
|
where
|
|
|
|
|
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
|
|
|
|
{
|
2022-02-04 09:28:52 -08:00
|
|
|
let mut key_data = BytesMut::from(&b"K"[..]);
|
|
|
|
|
key_data.put_i32(12);
|
|
|
|
|
key_data.put_i32(backend_id);
|
|
|
|
|
key_data.put_i32(secret_key);
|
|
|
|
|
|
2022-11-10 02:04:31 +08:00
|
|
|
write_all(stream, key_data).await
|
2022-02-04 09:28:52 -08:00
|
|
|
}
|
|
|
|
|
|
2022-02-24 08:44:41 -08:00
|
|
|
/// Construct a `Q`: Query message.
|
2022-02-19 08:57:24 -08:00
|
|
|
pub fn simple_query(query: &str) -> BytesMut {
|
|
|
|
|
let mut res = BytesMut::from(&b"Q"[..]);
|
|
|
|
|
let query = format!("{}\0", query);
|
|
|
|
|
|
|
|
|
|
res.put_i32(query.len() as i32 + 4);
|
2022-11-10 02:04:31 +08:00
|
|
|
res.put_slice(query.as_bytes());
|
2022-02-19 08:57:24 -08:00
|
|
|
|
|
|
|
|
res
|
|
|
|
|
}
|
|
|
|
|
|
2022-02-08 09:33:20 -08:00
|
|
|
/// Tell the client we're ready for another query.
|
2023-10-25 18:11:57 -04:00
|
|
|
pub async fn send_ready_for_query<S>(stream: &mut S) -> Result<(), Error>
|
2022-06-27 15:52:01 -07:00
|
|
|
where
|
|
|
|
|
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
|
|
|
|
{
|
2023-10-25 18:11:57 -04:00
|
|
|
write_all(stream, ready_for_query(false)).await
|
2022-02-03 13:35:40 -08:00
|
|
|
}
|
|
|
|
|
|
2022-02-08 09:33:20 -08:00
|
|
|
/// Send the startup packet the server. We're pretending we're a Pg client.
|
|
|
|
|
/// This tells the server which user we are and what database we want.
|
2023-04-30 09:41:46 -07:00
|
|
|
pub async fn startup<S>(stream: &mut S, user: &str, database: &str) -> Result<(), Error>
|
|
|
|
|
where
|
|
|
|
|
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
|
|
|
|
{
|
2022-02-03 15:17:04 -08:00
|
|
|
let mut bytes = BytesMut::with_capacity(25);
|
|
|
|
|
|
|
|
|
|
bytes.put_i32(196608); // Protocol number
|
|
|
|
|
|
|
|
|
|
// User
|
|
|
|
|
bytes.put(&b"user\0"[..]);
|
2022-11-10 02:04:31 +08:00
|
|
|
bytes.put_slice(user.as_bytes());
|
2022-02-03 15:17:04 -08:00
|
|
|
bytes.put_u8(0);
|
|
|
|
|
|
2023-08-10 11:18:46 -04:00
|
|
|
// Application name
|
|
|
|
|
bytes.put(&b"application_name\0"[..]);
|
|
|
|
|
bytes.put_slice(&b"pgcat\0"[..]);
|
|
|
|
|
|
2022-02-03 15:17:04 -08:00
|
|
|
// Database
|
|
|
|
|
bytes.put(&b"database\0"[..]);
|
2022-11-10 02:04:31 +08:00
|
|
|
bytes.put_slice(database.as_bytes());
|
2022-02-03 15:17:04 -08:00
|
|
|
bytes.put_u8(0);
|
|
|
|
|
bytes.put_u8(0); // Null terminator
|
|
|
|
|
|
|
|
|
|
let len = bytes.len() as i32 + 4i32;
|
|
|
|
|
|
|
|
|
|
let mut startup = BytesMut::with_capacity(len as usize);
|
|
|
|
|
|
|
|
|
|
startup.put_i32(len);
|
|
|
|
|
startup.put(bytes);
|
|
|
|
|
|
|
|
|
|
match stream.write_all(&startup).await {
|
|
|
|
|
Ok(_) => Ok(()),
|
2023-10-10 09:18:21 -07:00
|
|
|
Err(err) => Err(Error::SocketError(format!(
|
|
|
|
|
"Error writing startup to server socket - Error: {:?}",
|
|
|
|
|
err
|
|
|
|
|
))),
|
2022-02-03 15:17:04 -08:00
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2023-04-30 09:41:46 -07:00
|
|
|
pub async fn ssl_request(stream: &mut TcpStream) -> Result<(), Error> {
|
|
|
|
|
let mut bytes = BytesMut::with_capacity(12);
|
|
|
|
|
|
|
|
|
|
bytes.put_i32(8);
|
|
|
|
|
bytes.put_i32(80877103);
|
|
|
|
|
|
|
|
|
|
match stream.write_all(&bytes).await {
|
|
|
|
|
Ok(_) => Ok(()),
|
|
|
|
|
Err(err) => Err(Error::SocketError(format!(
|
|
|
|
|
"Error writing SSLRequest to server socket - Error: {:?}",
|
|
|
|
|
err
|
|
|
|
|
))),
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2022-02-26 11:01:52 -08:00
|
|
|
/// Parse the params the server sends as a key/value format.
|
|
|
|
|
pub fn parse_params(mut bytes: BytesMut) -> Result<HashMap<String, String>, Error> {
|
2022-02-14 05:11:53 -08:00
|
|
|
let mut result = HashMap::new();
|
|
|
|
|
let mut buf = Vec::new();
|
|
|
|
|
let mut tmp = String::new();
|
|
|
|
|
|
|
|
|
|
while bytes.has_remaining() {
|
|
|
|
|
let mut c = bytes.get_u8();
|
|
|
|
|
|
|
|
|
|
// Null-terminated C-strings.
|
|
|
|
|
while c != 0 {
|
|
|
|
|
tmp.push(c as char);
|
|
|
|
|
c = bytes.get_u8();
|
|
|
|
|
}
|
|
|
|
|
|
2022-11-10 02:04:31 +08:00
|
|
|
if !tmp.is_empty() {
|
2022-02-14 05:11:53 -08:00
|
|
|
buf.push(tmp.clone());
|
|
|
|
|
tmp.clear();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Expect pairs of name and value
|
|
|
|
|
// and at least one pair to be present.
|
2022-02-26 11:01:52 -08:00
|
|
|
if buf.len() % 2 != 0 || buf.len() < 2 {
|
2022-02-14 05:11:53 -08:00
|
|
|
return Err(Error::ClientBadStartup);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let mut i = 0;
|
|
|
|
|
while i < buf.len() {
|
|
|
|
|
let name = buf[i].clone();
|
|
|
|
|
let value = buf[i + 1].clone();
|
|
|
|
|
let _ = result.insert(name, value);
|
|
|
|
|
i += 2;
|
|
|
|
|
}
|
|
|
|
|
|
2022-02-26 11:01:52 -08:00
|
|
|
Ok(result)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Parse StartupMessage parameters.
|
|
|
|
|
/// e.g. user, database, application_name, etc.
|
|
|
|
|
pub fn parse_startup(bytes: BytesMut) -> Result<HashMap<String, String>, Error> {
|
|
|
|
|
let result = parse_params(bytes)?;
|
|
|
|
|
|
2022-02-14 05:11:53 -08:00
|
|
|
// Minimum required parameters
|
|
|
|
|
// I want to have the user at the very minimum, according to the protocol spec.
|
|
|
|
|
if !result.contains_key("user") {
|
|
|
|
|
return Err(Error::ClientBadStartup);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Ok(result)
|
|
|
|
|
}
|
|
|
|
|
|
2022-06-20 06:15:54 -07:00
|
|
|
/// Create md5 password hash given a salt.
|
|
|
|
|
pub fn md5_hash_password(user: &str, password: &str, salt: &[u8]) -> Vec<u8> {
|
2022-02-03 15:17:04 -08:00
|
|
|
let mut md5 = Md5::new();
|
|
|
|
|
|
|
|
|
|
// First pass
|
2023-10-10 09:18:21 -07:00
|
|
|
md5.update(password.as_bytes());
|
|
|
|
|
md5.update(user.as_bytes());
|
2022-02-03 15:17:04 -08:00
|
|
|
|
|
|
|
|
let output = md5.finalize_reset();
|
|
|
|
|
|
|
|
|
|
// Second pass
|
Auth passthrough (auth_query) (#266)
* Add a new exec_simple_query method
This adds a new `exec_simple_query` method so we can make 'out of band'
queries to servers that don't interfere with pools at all.
In order to reuse startup code for making these simple queries,
we need to set the stats (`Reporter`) optional, so using these
simple queries wont interfere with stats.
* Add auth passthough (auth_query)
Adds a feature that allows setting auth passthrough for md5 auth.
It adds 3 new (general and pool) config parameters:
- `auth_query`: An string containing a query that will be executed on boot
to obtain the hash of a given user. This query have to use a placeholder `$1`,
so pgcat can replace it with the user its trying to fetch the hash from.
- `auth_query_user`: The user to use for connecting to the server and executing the
auth_query.
- `auth_query_password`: The password to use for connecting to the server and executing the
auth_query.
The configuration can be done either on the general config (so pools share them) or in a per-pool basis.
The behavior is, at boot time, when validating server connections, a hash is fetched per server
and stored in the pool. When new server connections are created, and no cleartext password is specified,
the obtained hash is used for creating them, if the hash could not be obtained for whatever reason, it retries
it.
When client authentication is tried, it uses cleartext passwords if specified, it not, it checks whether
we have query_auth set up, if so, it tries to use the obtained hash for making client auth. If there is no
hash (we could not obtain one when validating the connection), a new fetch is tried.
Once we have a hash, we authenticate using it against whathever the client has sent us, if there is a failure
we refetch the hash and retry auth (so password changes can be done).
The idea with this 'retrial' mechanism is to make it fault tolerant, so if for whatever reason hash could not be
obtained during connection validation, or the password has change, we can still connect later.
* Add documentation for Auth passthrough
2023-03-30 22:29:23 +02:00
|
|
|
md5_hash_second_pass(&(format!("{:x}", output)), salt)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub fn md5_hash_second_pass(hash: &str, salt: &[u8]) -> Vec<u8> {
|
|
|
|
|
let mut md5 = Md5::new();
|
|
|
|
|
// Second pass
|
|
|
|
|
md5.update(hash);
|
2022-02-03 15:17:04 -08:00
|
|
|
md5.update(salt);
|
|
|
|
|
|
|
|
|
|
let mut password = format!("md5{:x}", md5.finalize())
|
|
|
|
|
.chars()
|
|
|
|
|
.map(|x| x as u8)
|
|
|
|
|
.collect::<Vec<u8>>();
|
|
|
|
|
password.push(0);
|
|
|
|
|
|
2022-06-20 06:15:54 -07:00
|
|
|
password
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Send password challenge response to the server.
|
|
|
|
|
/// This is the MD5 challenge.
|
2022-06-27 09:46:33 -07:00
|
|
|
pub async fn md5_password<S>(
|
|
|
|
|
stream: &mut S,
|
2022-06-20 06:15:54 -07:00
|
|
|
user: &str,
|
|
|
|
|
password: &str,
|
|
|
|
|
salt: &[u8],
|
2022-06-27 09:46:33 -07:00
|
|
|
) -> Result<(), Error>
|
2022-06-27 15:52:01 -07:00
|
|
|
where
|
|
|
|
|
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
|
|
|
|
{
|
2022-06-20 06:15:54 -07:00
|
|
|
let password = md5_hash_password(user, password, salt);
|
|
|
|
|
|
2023-10-10 09:18:21 -07:00
|
|
|
let mut message = BytesMut::with_capacity(password.len() + 5);
|
Auth passthrough (auth_query) (#266)
* Add a new exec_simple_query method
This adds a new `exec_simple_query` method so we can make 'out of band'
queries to servers that don't interfere with pools at all.
In order to reuse startup code for making these simple queries,
we need to set the stats (`Reporter`) optional, so using these
simple queries wont interfere with stats.
* Add auth passthough (auth_query)
Adds a feature that allows setting auth passthrough for md5 auth.
It adds 3 new (general and pool) config parameters:
- `auth_query`: An string containing a query that will be executed on boot
to obtain the hash of a given user. This query have to use a placeholder `$1`,
so pgcat can replace it with the user its trying to fetch the hash from.
- `auth_query_user`: The user to use for connecting to the server and executing the
auth_query.
- `auth_query_password`: The password to use for connecting to the server and executing the
auth_query.
The configuration can be done either on the general config (so pools share them) or in a per-pool basis.
The behavior is, at boot time, when validating server connections, a hash is fetched per server
and stored in the pool. When new server connections are created, and no cleartext password is specified,
the obtained hash is used for creating them, if the hash could not be obtained for whatever reason, it retries
it.
When client authentication is tried, it uses cleartext passwords if specified, it not, it checks whether
we have query_auth set up, if so, it tries to use the obtained hash for making client auth. If there is no
hash (we could not obtain one when validating the connection), a new fetch is tried.
Once we have a hash, we authenticate using it against whathever the client has sent us, if there is a failure
we refetch the hash and retry auth (so password changes can be done).
The idea with this 'retrial' mechanism is to make it fault tolerant, so if for whatever reason hash could not be
obtained during connection validation, or the password has change, we can still connect later.
* Add documentation for Auth passthrough
2023-03-30 22:29:23 +02:00
|
|
|
|
|
|
|
|
message.put_u8(b'p');
|
|
|
|
|
message.put_i32(password.len() as i32 + 4);
|
|
|
|
|
message.put_slice(&password[..]);
|
|
|
|
|
|
|
|
|
|
write_all(stream, message).await
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub async fn md5_password_with_hash<S>(stream: &mut S, hash: &str, salt: &[u8]) -> Result<(), Error>
|
|
|
|
|
where
|
|
|
|
|
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
|
|
|
|
{
|
|
|
|
|
let password = md5_hash_second_pass(hash, salt);
|
2023-10-10 09:18:21 -07:00
|
|
|
let mut message = BytesMut::with_capacity(password.len() + 5);
|
2022-02-15 22:45:45 -08:00
|
|
|
|
2022-02-03 15:17:04 -08:00
|
|
|
message.put_u8(b'p');
|
|
|
|
|
message.put_i32(password.len() as i32 + 4);
|
|
|
|
|
message.put_slice(&password[..]);
|
|
|
|
|
|
2022-11-10 02:04:31 +08:00
|
|
|
write_all(stream, message).await
|
2022-02-03 15:17:04 -08:00
|
|
|
}
|
|
|
|
|
|
2022-02-09 20:02:20 -08:00
|
|
|
/// Implements a response to our custom `SET SHARDING KEY`
|
|
|
|
|
/// and `SET SERVER ROLE` commands.
|
2022-02-09 06:51:31 -08:00
|
|
|
/// This tells the client we're ready for the next query.
|
2022-06-27 15:52:01 -07:00
|
|
|
pub async fn custom_protocol_response_ok<S>(stream: &mut S, message: &str) -> Result<(), Error>
|
|
|
|
|
where
|
|
|
|
|
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
|
|
|
|
{
|
2022-02-08 13:11:50 -08:00
|
|
|
let mut res = BytesMut::with_capacity(25);
|
|
|
|
|
|
2022-02-09 20:02:20 -08:00
|
|
|
let set_complete = BytesMut::from(&format!("{}\0", message)[..]);
|
2022-02-08 13:11:50 -08:00
|
|
|
let len = (set_complete.len() + 4) as i32;
|
|
|
|
|
|
2022-02-09 06:51:31 -08:00
|
|
|
// CommandComplete
|
2022-02-08 13:11:50 -08:00
|
|
|
res.put_u8(b'C');
|
|
|
|
|
res.put_i32(len);
|
|
|
|
|
res.put_slice(&set_complete[..]);
|
|
|
|
|
|
2023-01-17 20:39:55 -05:00
|
|
|
write_all_half(stream, &res).await?;
|
2023-10-25 18:11:57 -04:00
|
|
|
send_ready_for_query(stream).await
|
2022-02-08 13:11:50 -08:00
|
|
|
}
|
|
|
|
|
|
2022-02-16 22:52:11 -08:00
|
|
|
/// Send a custom error message to the client.
|
|
|
|
|
/// Tell the client we are ready for the next query and no rollback is necessary.
|
2022-03-10 01:33:29 -08:00
|
|
|
/// Docs on error codes: <https://www.postgresql.org/docs/12/errcodes-appendix.html>.
|
2022-06-27 09:46:33 -07:00
|
|
|
pub async fn error_response<S>(stream: &mut S, message: &str) -> Result<(), Error>
|
2022-08-08 19:01:24 -04:00
|
|
|
where
|
|
|
|
|
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
|
|
|
|
{
|
|
|
|
|
error_response_terminal(stream, message).await?;
|
2023-10-25 18:11:57 -04:00
|
|
|
send_ready_for_query(stream).await
|
2022-08-08 19:01:24 -04:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Send a custom error message to the client.
|
|
|
|
|
/// Tell the client we are ready for the next query and no rollback is necessary.
|
|
|
|
|
/// Docs on error codes: <https://www.postgresql.org/docs/12/errcodes-appendix.html>.
|
|
|
|
|
pub async fn error_response_terminal<S>(stream: &mut S, message: &str) -> Result<(), Error>
|
2022-06-27 15:52:01 -07:00
|
|
|
where
|
|
|
|
|
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
|
|
|
|
{
|
2022-02-16 22:52:11 -08:00
|
|
|
let mut error = BytesMut::new();
|
|
|
|
|
|
|
|
|
|
// Error level
|
|
|
|
|
error.put_u8(b'S');
|
|
|
|
|
error.put_slice(&b"FATAL\0"[..]);
|
|
|
|
|
|
|
|
|
|
// Error level (non-translatable)
|
|
|
|
|
error.put_u8(b'V');
|
|
|
|
|
error.put_slice(&b"FATAL\0"[..]);
|
|
|
|
|
|
|
|
|
|
// Error code: not sure how much this matters.
|
|
|
|
|
error.put_u8(b'C');
|
|
|
|
|
error.put_slice(&b"58000\0"[..]); // system_error, see Appendix A.
|
|
|
|
|
|
|
|
|
|
// The short error message.
|
|
|
|
|
error.put_u8(b'M');
|
2022-11-10 02:04:31 +08:00
|
|
|
error.put_slice(format!("{}\0", message).as_bytes());
|
2022-02-16 22:52:11 -08:00
|
|
|
|
|
|
|
|
// No more fields follow.
|
|
|
|
|
error.put_u8(0);
|
|
|
|
|
|
|
|
|
|
// Compose the two message reply.
|
2022-08-08 19:01:24 -04:00
|
|
|
let mut res = BytesMut::with_capacity(error.len() + 5);
|
2022-02-16 22:52:11 -08:00
|
|
|
|
|
|
|
|
res.put_u8(b'E');
|
|
|
|
|
res.put_i32(error.len() as i32 + 4);
|
|
|
|
|
res.put(error);
|
|
|
|
|
|
2023-01-17 20:39:55 -05:00
|
|
|
write_all_half(stream, &res).await
|
2022-02-16 22:52:11 -08:00
|
|
|
}
|
|
|
|
|
|
2022-06-27 09:46:33 -07:00
|
|
|
pub async fn wrong_password<S>(stream: &mut S, user: &str) -> Result<(), Error>
|
2022-06-27 15:52:01 -07:00
|
|
|
where
|
|
|
|
|
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
|
|
|
|
{
|
2022-06-20 06:15:54 -07:00
|
|
|
let mut error = BytesMut::new();
|
|
|
|
|
|
|
|
|
|
// Error level
|
|
|
|
|
error.put_u8(b'S');
|
|
|
|
|
error.put_slice(&b"FATAL\0"[..]);
|
|
|
|
|
|
|
|
|
|
// Error level (non-translatable)
|
|
|
|
|
error.put_u8(b'V');
|
|
|
|
|
error.put_slice(&b"FATAL\0"[..]);
|
|
|
|
|
|
|
|
|
|
// Error code: not sure how much this matters.
|
|
|
|
|
error.put_u8(b'C');
|
|
|
|
|
error.put_slice(&b"28P01\0"[..]); // system_error, see Appendix A.
|
|
|
|
|
|
|
|
|
|
// The short error message.
|
|
|
|
|
error.put_u8(b'M');
|
2022-11-10 02:04:31 +08:00
|
|
|
error.put_slice(format!("password authentication failed for user \"{}\"\0", user).as_bytes());
|
2022-06-20 06:15:54 -07:00
|
|
|
|
|
|
|
|
// No more fields follow.
|
|
|
|
|
error.put_u8(0);
|
|
|
|
|
|
|
|
|
|
// Compose the two message reply.
|
|
|
|
|
let mut res = BytesMut::new();
|
|
|
|
|
|
|
|
|
|
res.put_u8(b'E');
|
|
|
|
|
res.put_i32(error.len() as i32 + 4);
|
|
|
|
|
|
|
|
|
|
res.put(error);
|
|
|
|
|
|
|
|
|
|
write_all(stream, res).await
|
|
|
|
|
}
|
|
|
|
|
|
2022-02-19 08:57:24 -08:00
|
|
|
/// Respond to a SHOW SHARD command.
|
2022-06-27 15:52:01 -07:00
|
|
|
pub async fn show_response<S>(stream: &mut S, name: &str, value: &str) -> Result<(), Error>
|
|
|
|
|
where
|
|
|
|
|
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
|
|
|
|
{
|
2022-02-19 08:57:24 -08:00
|
|
|
// A SELECT response consists of:
|
|
|
|
|
// 1. RowDescription
|
|
|
|
|
// 2. One or more DataRow
|
|
|
|
|
// 3. CommandComplete
|
|
|
|
|
// 4. ReadyForQuery
|
|
|
|
|
|
2022-03-01 08:47:19 -08:00
|
|
|
// The final messages sent to the client
|
|
|
|
|
let mut res = BytesMut::new();
|
2022-02-19 08:57:24 -08:00
|
|
|
|
2022-03-01 08:47:19 -08:00
|
|
|
// RowDescription
|
|
|
|
|
res.put(row_description(&vec![(name, DataType::Text)]));
|
2022-02-19 08:57:24 -08:00
|
|
|
|
2022-03-01 08:47:19 -08:00
|
|
|
// DataRow
|
|
|
|
|
res.put(data_row(&vec![value.to_string()]));
|
2022-02-19 08:57:24 -08:00
|
|
|
|
2022-03-01 08:47:19 -08:00
|
|
|
// CommandComplete
|
|
|
|
|
res.put(command_complete("SELECT 1"));
|
2022-02-19 08:57:24 -08:00
|
|
|
|
2023-01-17 20:39:55 -05:00
|
|
|
write_all_half(stream, &res).await?;
|
2023-10-25 18:11:57 -04:00
|
|
|
send_ready_for_query(stream).await
|
2022-03-01 08:47:19 -08:00
|
|
|
}
|
2022-02-19 08:57:24 -08:00
|
|
|
|
2022-03-01 08:47:19 -08:00
|
|
|
pub fn row_description(columns: &Vec<(&str, DataType)>) -> BytesMut {
|
|
|
|
|
let mut res = BytesMut::new();
|
|
|
|
|
let mut row_desc = BytesMut::new();
|
2022-02-19 08:57:24 -08:00
|
|
|
|
2023-04-11 09:37:16 +08:00
|
|
|
// how many columns we are storing
|
2022-03-01 08:47:19 -08:00
|
|
|
row_desc.put_i16(columns.len() as i16);
|
2022-02-19 08:57:24 -08:00
|
|
|
|
2022-03-01 08:47:19 -08:00
|
|
|
for (name, data_type) in columns {
|
|
|
|
|
// Column name
|
2022-11-10 02:04:31 +08:00
|
|
|
row_desc.put_slice(format!("{}\0", name).as_bytes());
|
2022-02-19 08:57:24 -08:00
|
|
|
|
2022-03-01 08:47:19 -08:00
|
|
|
// Doesn't belong to any table
|
|
|
|
|
row_desc.put_i32(0);
|
2022-02-19 08:57:24 -08:00
|
|
|
|
2022-03-01 08:47:19 -08:00
|
|
|
// Doesn't belong to any table
|
|
|
|
|
row_desc.put_i16(0);
|
2022-02-19 08:57:24 -08:00
|
|
|
|
2022-03-01 08:47:19 -08:00
|
|
|
// Text
|
|
|
|
|
row_desc.put_i32(data_type.into());
|
2022-02-19 08:57:24 -08:00
|
|
|
|
2022-03-01 08:47:19 -08:00
|
|
|
// Text size = variable (-1)
|
|
|
|
|
let type_size = match data_type {
|
|
|
|
|
DataType::Text => -1,
|
|
|
|
|
DataType::Int4 => 4,
|
|
|
|
|
DataType::Numeric => -1,
|
2023-05-03 09:13:05 -07:00
|
|
|
DataType::Bool => 1,
|
|
|
|
|
DataType::Oid => 4,
|
|
|
|
|
DataType::AnyArray => -1,
|
|
|
|
|
DataType::Any => -1,
|
2022-03-01 08:47:19 -08:00
|
|
|
};
|
2022-02-19 08:57:24 -08:00
|
|
|
|
2022-03-01 08:47:19 -08:00
|
|
|
row_desc.put_i16(type_size);
|
2022-02-19 08:57:24 -08:00
|
|
|
|
2022-03-01 08:47:19 -08:00
|
|
|
// Type modifier: none that I know
|
|
|
|
|
row_desc.put_i32(-1);
|
2022-02-19 08:57:24 -08:00
|
|
|
|
2022-03-01 08:47:19 -08:00
|
|
|
// Format being used: text (0), binary (1)
|
|
|
|
|
row_desc.put_i16(0);
|
|
|
|
|
}
|
2022-02-19 08:57:24 -08:00
|
|
|
|
|
|
|
|
res.put_u8(b'T');
|
|
|
|
|
res.put_i32(row_desc.len() as i32 + 4);
|
|
|
|
|
res.put(row_desc);
|
|
|
|
|
|
2022-03-01 08:47:19 -08:00
|
|
|
res
|
|
|
|
|
}
|
|
|
|
|
|
2022-03-10 01:33:29 -08:00
|
|
|
/// Create a DataRow message.
|
2022-03-01 08:47:19 -08:00
|
|
|
pub fn data_row(row: &Vec<String>) -> BytesMut {
|
|
|
|
|
let mut res = BytesMut::new();
|
|
|
|
|
let mut data_row = BytesMut::new();
|
|
|
|
|
|
|
|
|
|
data_row.put_i16(row.len() as i16);
|
|
|
|
|
|
|
|
|
|
for column in row {
|
|
|
|
|
let column = column.as_bytes();
|
|
|
|
|
data_row.put_i32(column.len() as i32);
|
2022-11-10 02:04:31 +08:00
|
|
|
data_row.put_slice(column);
|
2022-03-01 08:47:19 -08:00
|
|
|
}
|
|
|
|
|
|
2022-02-19 08:57:24 -08:00
|
|
|
res.put_u8(b'D');
|
|
|
|
|
res.put_i32(data_row.len() as i32 + 4);
|
|
|
|
|
res.put(data_row);
|
|
|
|
|
|
2022-03-01 08:47:19 -08:00
|
|
|
res
|
|
|
|
|
}
|
2022-02-19 08:57:24 -08:00
|
|
|
|
2023-05-03 09:13:05 -07:00
|
|
|
pub fn data_row_nullable(row: &Vec<Option<String>>) -> BytesMut {
|
|
|
|
|
let mut res = BytesMut::new();
|
|
|
|
|
let mut data_row = BytesMut::new();
|
|
|
|
|
|
|
|
|
|
data_row.put_i16(row.len() as i16);
|
|
|
|
|
|
|
|
|
|
for column in row {
|
|
|
|
|
if let Some(column) = column {
|
|
|
|
|
let column = column.as_bytes();
|
|
|
|
|
data_row.put_i32(column.len() as i32);
|
|
|
|
|
data_row.put_slice(column);
|
|
|
|
|
} else {
|
2023-10-10 09:18:21 -07:00
|
|
|
data_row.put_i32(-1_i32);
|
2023-05-03 09:13:05 -07:00
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
res.put_u8(b'D');
|
|
|
|
|
res.put_i32(data_row.len() as i32 + 4);
|
|
|
|
|
res.put(data_row);
|
|
|
|
|
|
|
|
|
|
res
|
|
|
|
|
}
|
|
|
|
|
|
2022-03-10 01:33:29 -08:00
|
|
|
/// Create a CommandComplete message.
|
2022-03-01 08:47:19 -08:00
|
|
|
pub fn command_complete(command: &str) -> BytesMut {
|
|
|
|
|
let cmd = BytesMut::from(format!("{}\0", command).as_bytes());
|
|
|
|
|
let mut res = BytesMut::new();
|
|
|
|
|
res.put_u8(b'C');
|
|
|
|
|
res.put_i32(cmd.len() as i32 + 4);
|
|
|
|
|
res.put(cmd);
|
|
|
|
|
res
|
2022-02-19 08:57:24 -08:00
|
|
|
}
|
|
|
|
|
|
2023-07-14 02:40:04 -03:00
|
|
|
/// Create a notify message.
|
|
|
|
|
pub fn notify(message: &str, details: String) -> BytesMut {
|
|
|
|
|
let mut notify_cmd = BytesMut::new();
|
|
|
|
|
|
|
|
|
|
notify_cmd.put_slice("SNOTICE\0".as_bytes());
|
|
|
|
|
notify_cmd.put_slice("C00000\0".as_bytes());
|
|
|
|
|
notify_cmd.put_slice(format!("M{}\0", message).as_bytes());
|
|
|
|
|
notify_cmd.put_slice(format!("D{}\0", details).as_bytes());
|
|
|
|
|
|
|
|
|
|
// this extra byte says that is the end of the package
|
|
|
|
|
notify_cmd.put_u8(0);
|
|
|
|
|
|
|
|
|
|
let mut res = BytesMut::new();
|
|
|
|
|
res.put_u8(b'N');
|
|
|
|
|
res.put_i32(notify_cmd.len() as i32 + 4);
|
|
|
|
|
res.put(notify_cmd);
|
|
|
|
|
|
|
|
|
|
res
|
|
|
|
|
}
|
|
|
|
|
|
2023-06-16 12:57:44 -07:00
|
|
|
pub fn flush() -> BytesMut {
|
|
|
|
|
let mut bytes = BytesMut::new();
|
|
|
|
|
bytes.put_u8(b'H');
|
|
|
|
|
bytes.put_i32(4);
|
|
|
|
|
bytes
|
|
|
|
|
}
|
|
|
|
|
|
2023-10-25 18:11:57 -04:00
|
|
|
pub fn sync() -> BytesMut {
|
|
|
|
|
let mut bytes = BytesMut::with_capacity(mem::size_of::<u8>() + mem::size_of::<i32>());
|
|
|
|
|
bytes.put_u8(b'S');
|
|
|
|
|
bytes.put_i32(4);
|
|
|
|
|
bytes
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub fn parse_complete() -> BytesMut {
|
|
|
|
|
let mut bytes = BytesMut::with_capacity(mem::size_of::<u8>() + mem::size_of::<i32>());
|
|
|
|
|
|
|
|
|
|
bytes.put_u8(b'1');
|
|
|
|
|
bytes.put_i32(4);
|
|
|
|
|
bytes
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub fn ready_for_query(in_transaction: bool) -> BytesMut {
|
|
|
|
|
let mut bytes = BytesMut::with_capacity(
|
|
|
|
|
mem::size_of::<u8>() + mem::size_of::<i32>() + mem::size_of::<u8>(),
|
|
|
|
|
);
|
|
|
|
|
|
|
|
|
|
bytes.put_u8(b'Z');
|
|
|
|
|
bytes.put_i32(5);
|
|
|
|
|
if in_transaction {
|
|
|
|
|
bytes.put_u8(b'T');
|
|
|
|
|
} else {
|
|
|
|
|
bytes.put_u8(b'I');
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bytes
|
|
|
|
|
}
|
|
|
|
|
|
2022-02-08 09:33:20 -08:00
|
|
|
/// Write all data in the buffer to the TcpStream.
|
2022-06-27 09:46:33 -07:00
|
|
|
pub async fn write_all<S>(stream: &mut S, buf: BytesMut) -> Result<(), Error>
|
2022-06-27 15:52:01 -07:00
|
|
|
where
|
|
|
|
|
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
|
|
|
|
{
|
2022-02-03 13:35:40 -08:00
|
|
|
match stream.write_all(&buf).await {
|
|
|
|
|
Ok(_) => Ok(()),
|
2023-10-10 09:18:21 -07:00
|
|
|
Err(err) => Err(Error::SocketError(format!(
|
|
|
|
|
"Error writing to socket - Error: {:?}",
|
|
|
|
|
err
|
|
|
|
|
))),
|
2022-02-03 13:35:40 -08:00
|
|
|
}
|
2022-02-03 13:54:07 -08:00
|
|
|
}
|
|
|
|
|
|
2022-02-08 09:33:20 -08:00
|
|
|
/// Write all the data in the buffer to the TcpStream, write owned half (see mpsc).
|
2023-01-17 20:39:55 -05:00
|
|
|
pub async fn write_all_half<S>(stream: &mut S, buf: &BytesMut) -> Result<(), Error>
|
2022-06-27 15:52:01 -07:00
|
|
|
where
|
|
|
|
|
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
|
|
|
|
{
|
2023-01-17 20:39:55 -05:00
|
|
|
match stream.write_all(buf).await {
|
2022-02-03 15:17:04 -08:00
|
|
|
Ok(_) => Ok(()),
|
2023-10-10 09:18:21 -07:00
|
|
|
Err(err) => Err(Error::SocketError(format!(
|
|
|
|
|
"Error writing to socket - Error: {:?}",
|
|
|
|
|
err
|
|
|
|
|
))),
|
2022-02-03 15:17:04 -08:00
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2023-04-30 09:41:46 -07:00
|
|
|
pub async fn write_all_flush<S>(stream: &mut S, buf: &[u8]) -> Result<(), Error>
|
|
|
|
|
where
|
|
|
|
|
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
|
|
|
|
{
|
|
|
|
|
match stream.write_all(buf).await {
|
|
|
|
|
Ok(_) => match stream.flush().await {
|
|
|
|
|
Ok(_) => Ok(()),
|
2023-10-10 09:18:21 -07:00
|
|
|
Err(err) => Err(Error::SocketError(format!(
|
|
|
|
|
"Error flushing socket - Error: {:?}",
|
2023-04-30 09:41:46 -07:00
|
|
|
err
|
2023-10-10 09:18:21 -07:00
|
|
|
))),
|
|
|
|
|
},
|
|
|
|
|
Err(err) => Err(Error::SocketError(format!(
|
|
|
|
|
"Error writing to socket - Error: {:?}",
|
|
|
|
|
err
|
|
|
|
|
))),
|
2023-04-30 09:41:46 -07:00
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2022-02-03 13:54:07 -08:00
|
|
|
/// Read a complete message from the socket.
|
2022-06-27 15:52:01 -07:00
|
|
|
pub async fn read_message<S>(stream: &mut S) -> Result<BytesMut, Error>
|
|
|
|
|
where
|
|
|
|
|
S: tokio::io::AsyncRead + std::marker::Unpin,
|
|
|
|
|
{
|
2022-02-03 13:54:07 -08:00
|
|
|
let code = match stream.read_u8().await {
|
|
|
|
|
Ok(code) => code,
|
2023-01-19 05:18:08 -06:00
|
|
|
Err(err) => {
|
2022-11-17 09:24:39 -08:00
|
|
|
return Err(Error::SocketError(format!(
|
2023-01-19 05:18:08 -06:00
|
|
|
"Error reading message code from socket - Error {:?}",
|
|
|
|
|
err
|
2022-11-17 09:24:39 -08:00
|
|
|
)))
|
|
|
|
|
}
|
2022-02-03 13:54:07 -08:00
|
|
|
};
|
|
|
|
|
|
|
|
|
|
let len = match stream.read_i32().await {
|
|
|
|
|
Ok(len) => len,
|
2023-01-19 05:18:08 -06:00
|
|
|
Err(err) => {
|
2022-11-17 09:24:39 -08:00
|
|
|
return Err(Error::SocketError(format!(
|
2023-01-19 05:18:08 -06:00
|
|
|
"Error reading message len from socket - Code: {:?}, Error: {:?}",
|
|
|
|
|
code, err
|
2022-11-17 09:24:39 -08:00
|
|
|
)))
|
|
|
|
|
}
|
2022-02-03 13:54:07 -08:00
|
|
|
};
|
|
|
|
|
|
2023-01-16 23:22:06 -05:00
|
|
|
let mut bytes = BytesMut::with_capacity(len as usize + 1);
|
|
|
|
|
|
|
|
|
|
bytes.put_u8(code);
|
|
|
|
|
bytes.put_i32(len);
|
|
|
|
|
|
|
|
|
|
bytes.resize(bytes.len() + len as usize - mem::size_of::<i32>(), b'0');
|
2022-02-03 13:54:07 -08:00
|
|
|
|
2023-03-17 12:31:43 -05:00
|
|
|
let slice_start = mem::size_of::<u8>() + mem::size_of::<i32>();
|
|
|
|
|
let slice_end = slice_start + len as usize - mem::size_of::<i32>();
|
|
|
|
|
|
|
|
|
|
// Avoids a panic
|
|
|
|
|
if slice_end < slice_start {
|
|
|
|
|
return Err(Error::SocketError(format!(
|
|
|
|
|
"Error reading message from socket - Code: {:?} - Length {:?}, Error: {:?}",
|
|
|
|
|
code, len, "Unexpected length value for message"
|
|
|
|
|
)));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
match stream.read_exact(&mut bytes[slice_start..slice_end]).await {
|
2022-02-03 13:54:07 -08:00
|
|
|
Ok(_) => (),
|
2023-01-19 05:18:08 -06:00
|
|
|
Err(err) => {
|
2022-11-17 09:24:39 -08:00
|
|
|
return Err(Error::SocketError(format!(
|
2023-01-19 05:18:08 -06:00
|
|
|
"Error reading message from socket - Code: {:?}, Error: {:?}",
|
|
|
|
|
code, err
|
2022-11-17 09:24:39 -08:00
|
|
|
)))
|
|
|
|
|
}
|
2022-02-03 13:54:07 -08:00
|
|
|
};
|
|
|
|
|
|
|
|
|
|
Ok(bytes)
|
2022-02-03 15:17:04 -08:00
|
|
|
}
|
2022-07-31 21:52:23 -05:00
|
|
|
|
2022-09-20 21:47:32 -04:00
|
|
|
pub fn server_parameter_message(key: &str, value: &str) -> BytesMut {
|
2022-07-31 21:52:23 -05:00
|
|
|
let mut server_info = BytesMut::new();
|
|
|
|
|
|
|
|
|
|
let null_byte_size = 1;
|
|
|
|
|
let len: usize =
|
|
|
|
|
mem::size_of::<i32>() + key.len() + null_byte_size + value.len() + null_byte_size;
|
|
|
|
|
|
|
|
|
|
server_info.put_slice("S".as_bytes());
|
|
|
|
|
server_info.put_i32(len.try_into().unwrap());
|
|
|
|
|
server_info.put_slice(key.as_bytes());
|
|
|
|
|
server_info.put_bytes(0, 1);
|
|
|
|
|
server_info.put_slice(value.as_bytes());
|
|
|
|
|
server_info.put_bytes(0, 1);
|
|
|
|
|
|
2022-11-10 02:04:31 +08:00
|
|
|
server_info
|
2022-07-31 21:52:23 -05:00
|
|
|
}
|
2023-01-19 10:19:49 -05:00
|
|
|
|
2023-02-08 11:35:38 -06:00
|
|
|
pub fn configure_socket(stream: &TcpStream) {
|
|
|
|
|
let sock_ref = SockRef::from(stream);
|
|
|
|
|
let conf = get_config();
|
|
|
|
|
|
2023-07-12 13:24:30 -05:00
|
|
|
#[cfg(target_os = "linux")]
|
|
|
|
|
match sock_ref.set_tcp_user_timeout(Some(Duration::from_millis(conf.general.tcp_user_timeout)))
|
|
|
|
|
{
|
|
|
|
|
Ok(_) => (),
|
|
|
|
|
Err(err) => error!("Could not configure tcp_user_timeout for socket: {}", err),
|
|
|
|
|
}
|
|
|
|
|
|
2023-02-08 11:35:38 -06:00
|
|
|
match sock_ref.set_keepalive(true) {
|
|
|
|
|
Ok(_) => {
|
|
|
|
|
match sock_ref.set_tcp_keepalive(
|
|
|
|
|
&TcpKeepalive::new()
|
|
|
|
|
.with_interval(Duration::from_secs(conf.general.tcp_keepalives_interval))
|
|
|
|
|
.with_retries(conf.general.tcp_keepalives_count)
|
|
|
|
|
.with_time(Duration::from_secs(conf.general.tcp_keepalives_idle)),
|
|
|
|
|
) {
|
|
|
|
|
Ok(_) => (),
|
2023-07-12 13:24:30 -05:00
|
|
|
Err(err) => error!("Could not configure tcp_keepalive for socket: {}", err),
|
2023-02-08 11:35:38 -06:00
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
Err(err) => error!("Could not configure socket: {}", err),
|
|
|
|
|
}
|
2024-05-27 00:47:21 +03:00
|
|
|
match sock_ref.set_nodelay(true) {
|
|
|
|
|
Ok(_) => (),
|
|
|
|
|
Err(err) => error!("Could not configure TCP_NODELAY for socket: {}", err),
|
|
|
|
|
}
|
2023-02-08 11:35:38 -06:00
|
|
|
}
|
|
|
|
|
|
2023-01-19 10:19:49 -05:00
|
|
|
pub trait BytesMutReader {
|
|
|
|
|
fn read_string(&mut self) -> Result<String, Error>;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl BytesMutReader for Cursor<&BytesMut> {
|
|
|
|
|
/// Should only be used when reading strings from the message protocol.
|
|
|
|
|
/// Can be used to read multiple strings from the same message which are separated by the null byte
|
|
|
|
|
fn read_string(&mut self) -> Result<String, Error> {
|
|
|
|
|
let mut buf = vec![];
|
|
|
|
|
match self.read_until(b'\0', &mut buf) {
|
|
|
|
|
Ok(_) => Ok(String::from_utf8_lossy(&buf[..buf.len() - 1]).to_string()),
|
2023-10-10 09:18:21 -07:00
|
|
|
Err(err) => Err(Error::ParseBytesError(err.to_string())),
|
2023-01-19 10:19:49 -05:00
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
2023-06-16 12:57:44 -07:00
|
|
|
|
2023-08-10 11:18:46 -04:00
|
|
|
impl BytesMutReader for BytesMut {
|
|
|
|
|
/// Should only be used when reading strings from the message protocol.
|
|
|
|
|
/// Can be used to read multiple strings from the same message which are separated by the null byte
|
|
|
|
|
fn read_string(&mut self) -> Result<String, Error> {
|
|
|
|
|
let null_index = self.iter().position(|&byte| byte == b'\0');
|
|
|
|
|
|
|
|
|
|
match null_index {
|
|
|
|
|
Some(index) => {
|
|
|
|
|
let string_bytes = self.split_to(index + 1);
|
|
|
|
|
Ok(String::from_utf8_lossy(&string_bytes[..string_bytes.len() - 1]).to_string())
|
|
|
|
|
}
|
2023-10-10 09:18:21 -07:00
|
|
|
None => Err(Error::ParseBytesError("Could not read string".to_string())),
|
2023-08-10 11:18:46 -04:00
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
2023-10-25 18:11:57 -04:00
|
|
|
|
|
|
|
|
pub enum ExtendedProtocolData {
|
|
|
|
|
Parse {
|
|
|
|
|
data: BytesMut,
|
|
|
|
|
metadata: Option<(Arc<Parse>, u64)>,
|
|
|
|
|
},
|
|
|
|
|
Bind {
|
|
|
|
|
data: BytesMut,
|
|
|
|
|
metadata: Option<String>,
|
|
|
|
|
},
|
|
|
|
|
Describe {
|
|
|
|
|
data: BytesMut,
|
|
|
|
|
metadata: Option<String>,
|
|
|
|
|
},
|
|
|
|
|
Execute {
|
|
|
|
|
data: BytesMut,
|
|
|
|
|
},
|
|
|
|
|
Close {
|
|
|
|
|
data: BytesMut,
|
|
|
|
|
close: Close,
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl ExtendedProtocolData {
|
|
|
|
|
pub fn create_new_parse(data: BytesMut, metadata: Option<(Arc<Parse>, u64)>) -> Self {
|
|
|
|
|
Self::Parse { data, metadata }
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub fn create_new_bind(data: BytesMut, metadata: Option<String>) -> Self {
|
|
|
|
|
Self::Bind { data, metadata }
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub fn create_new_describe(data: BytesMut, metadata: Option<String>) -> Self {
|
|
|
|
|
Self::Describe { data, metadata }
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub fn create_new_execute(data: BytesMut) -> Self {
|
|
|
|
|
Self::Execute { data }
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub fn create_new_close(data: BytesMut, close: Close) -> Self {
|
|
|
|
|
Self::Close { data, close }
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2023-06-16 12:57:44 -07:00
|
|
|
/// Parse (F) message.
/// See: <https://www.postgresql.org/docs/current/protocol-message-formats.html>
///
/// Fields are listed in the order they appear in the message.
#[derive(Clone, Debug)]
pub struct Parse {
    /// Message code byte.
    code: char,
    /// Length as read from the message; it is recomputed when serializing.
    #[allow(dead_code)]
    len: i32,
    /// Name of the prepared statement.
    pub name: String,
    /// Query string.
    query: String,
    /// Parameter type count, exactly as read from the message.
    num_params: i16,
    /// One entry per parameter type read from the message.
    param_types: Vec<i32>,
}
|
|
|
|
|
|
|
|
|
|
impl TryFrom<&BytesMut> for Parse {
|
|
|
|
|
type Error = Error;
|
|
|
|
|
|
|
|
|
|
fn try_from(buf: &BytesMut) -> Result<Parse, Error> {
|
|
|
|
|
let mut cursor = Cursor::new(buf);
|
|
|
|
|
let code = cursor.get_u8() as char;
|
|
|
|
|
let len = cursor.get_i32();
|
|
|
|
|
let name = cursor.read_string()?;
|
|
|
|
|
let query = cursor.read_string()?;
|
|
|
|
|
let num_params = cursor.get_i16();
|
|
|
|
|
let mut param_types = Vec::new();
|
|
|
|
|
|
|
|
|
|
for _ in 0..num_params {
|
|
|
|
|
param_types.push(cursor.get_i32());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Ok(Parse {
|
|
|
|
|
code,
|
|
|
|
|
len,
|
|
|
|
|
name,
|
|
|
|
|
query,
|
|
|
|
|
num_params,
|
|
|
|
|
param_types,
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl TryFrom<Parse> for BytesMut {
|
|
|
|
|
type Error = Error;
|
|
|
|
|
|
|
|
|
|
fn try_from(parse: Parse) -> Result<BytesMut, Error> {
|
|
|
|
|
let mut bytes = BytesMut::new();
|
|
|
|
|
|
|
|
|
|
let name_binding = CString::new(parse.name)?;
|
|
|
|
|
let name = name_binding.as_bytes_with_nul();
|
|
|
|
|
|
|
|
|
|
let query_binding = CString::new(parse.query)?;
|
|
|
|
|
let query = query_binding.as_bytes_with_nul();
|
|
|
|
|
|
|
|
|
|
// Recompute length of the message.
|
|
|
|
|
let len = 4 // self
|
|
|
|
|
+ name.len()
|
|
|
|
|
+ query.len()
|
|
|
|
|
+ 2
|
|
|
|
|
+ 4 * parse.num_params as usize;
|
|
|
|
|
|
|
|
|
|
bytes.put_u8(parse.code as u8);
|
|
|
|
|
bytes.put_i32(len as i32);
|
|
|
|
|
bytes.put_slice(name);
|
|
|
|
|
bytes.put_slice(query);
|
|
|
|
|
bytes.put_i16(parse.num_params);
|
|
|
|
|
for param in parse.param_types {
|
|
|
|
|
bytes.put_i32(param);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Ok(bytes)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl TryFrom<&Parse> for BytesMut {
|
|
|
|
|
type Error = Error;
|
|
|
|
|
|
|
|
|
|
fn try_from(parse: &Parse) -> Result<BytesMut, Error> {
|
|
|
|
|
parse.clone().try_into()
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl Parse {
|
2023-10-25 18:11:57 -04:00
|
|
|
/// Renames the prepared statement to a new name based on the global counter
|
|
|
|
|
pub fn rewrite(mut self) -> Self {
|
|
|
|
|
self.name = format!(
|
|
|
|
|
"PGCAT_{}",
|
|
|
|
|
PREPARED_STATEMENT_COUNTER.fetch_add(1, Ordering::SeqCst)
|
|
|
|
|
);
|
2023-06-16 12:57:44 -07:00
|
|
|
self
|
|
|
|
|
}
|
|
|
|
|
|
2023-10-25 18:11:57 -04:00
|
|
|
/// Gets the name of the prepared statement from the buffer
|
|
|
|
|
pub fn get_name(buf: &BytesMut) -> Result<String, Error> {
|
|
|
|
|
let mut cursor = Cursor::new(buf);
|
|
|
|
|
// Skip the code and length
|
|
|
|
|
cursor.advance(mem::size_of::<u8>() + mem::size_of::<i32>());
|
|
|
|
|
cursor.read_string()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Hashes the parse statement to be used as a key in the global cache
|
|
|
|
|
pub fn get_hash(&self) -> u64 {
|
|
|
|
|
// TODO_ZAIN: Take a look at which hashing function is being used
|
|
|
|
|
let mut hasher = DefaultHasher::new();
|
|
|
|
|
|
|
|
|
|
let concatenated = format!(
|
|
|
|
|
"{}{}{}",
|
|
|
|
|
self.query,
|
|
|
|
|
self.num_params,
|
|
|
|
|
self.param_types
|
|
|
|
|
.iter()
|
|
|
|
|
.map(ToString::to_string)
|
|
|
|
|
.collect::<Vec<_>>()
|
|
|
|
|
.join(",")
|
|
|
|
|
);
|
|
|
|
|
|
|
|
|
|
concatenated.hash(&mut hasher);
|
|
|
|
|
|
|
|
|
|
hasher.finish()
|
|
|
|
|
}
|
|
|
|
|
|
2023-06-16 12:57:44 -07:00
|
|
|
pub fn anonymous(&self) -> bool {
|
|
|
|
|
self.name.is_empty()
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Bind (B) message.
|
|
|
|
|
/// See: <https://www.postgresql.org/docs/current/protocol-message-formats.html>
|
|
|
|
|
#[derive(Clone, Debug)]
|
|
|
|
|
pub struct Bind {
|
|
|
|
|
code: char,
|
|
|
|
|
#[allow(dead_code)]
|
|
|
|
|
len: i64,
|
|
|
|
|
portal: String,
|
|
|
|
|
pub prepared_statement: String,
|
|
|
|
|
num_param_format_codes: i16,
|
|
|
|
|
param_format_codes: Vec<i16>,
|
|
|
|
|
num_param_values: i16,
|
|
|
|
|
param_values: Vec<(i32, BytesMut)>,
|
|
|
|
|
num_result_column_format_codes: i16,
|
|
|
|
|
result_columns_format_codes: Vec<i16>,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl TryFrom<&BytesMut> for Bind {
|
|
|
|
|
type Error = Error;
|
|
|
|
|
|
|
|
|
|
fn try_from(buf: &BytesMut) -> Result<Bind, Error> {
|
|
|
|
|
let mut cursor = Cursor::new(buf);
|
|
|
|
|
let code = cursor.get_u8() as char;
|
|
|
|
|
let len = cursor.get_i32();
|
|
|
|
|
let portal = cursor.read_string()?;
|
|
|
|
|
let prepared_statement = cursor.read_string()?;
|
|
|
|
|
let num_param_format_codes = cursor.get_i16();
|
|
|
|
|
let mut param_format_codes = Vec::new();
|
|
|
|
|
|
|
|
|
|
for _ in 0..num_param_format_codes {
|
|
|
|
|
param_format_codes.push(cursor.get_i16());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let num_param_values = cursor.get_i16();
|
|
|
|
|
let mut param_values = Vec::new();
|
|
|
|
|
|
|
|
|
|
for _ in 0..num_param_values {
|
|
|
|
|
let param_len = cursor.get_i32();
|
2023-07-10 11:35:43 +03:00
|
|
|
// There is special occasion when the parameter is NULL
|
|
|
|
|
// In that case, param length is defined as -1
|
|
|
|
|
// So if the passed parameter len is over 0
|
|
|
|
|
if param_len > 0 {
|
|
|
|
|
let mut param = BytesMut::with_capacity(param_len as usize);
|
|
|
|
|
param.resize(param_len as usize, b'0');
|
|
|
|
|
cursor.copy_to_slice(&mut param);
|
|
|
|
|
// we push and the length and the parameter into vector
|
|
|
|
|
param_values.push((param_len, param));
|
|
|
|
|
} else {
|
|
|
|
|
// otherwise we push a tuple with -1 and 0-len BytesMut
|
|
|
|
|
// which means that after encountering -1 postgres proceeds
|
|
|
|
|
// to processing another parameter
|
|
|
|
|
param_values.push((param_len, BytesMut::new()));
|
|
|
|
|
}
|
2023-06-16 12:57:44 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let num_result_column_format_codes = cursor.get_i16();
|
|
|
|
|
let mut result_columns_format_codes = Vec::new();
|
|
|
|
|
|
|
|
|
|
for _ in 0..num_result_column_format_codes {
|
|
|
|
|
result_columns_format_codes.push(cursor.get_i16());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Ok(Bind {
|
|
|
|
|
code,
|
|
|
|
|
len: len as i64,
|
|
|
|
|
portal,
|
|
|
|
|
prepared_statement,
|
|
|
|
|
num_param_format_codes,
|
|
|
|
|
param_format_codes,
|
|
|
|
|
num_param_values,
|
|
|
|
|
param_values,
|
|
|
|
|
num_result_column_format_codes,
|
|
|
|
|
result_columns_format_codes,
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl TryFrom<Bind> for BytesMut {
|
|
|
|
|
type Error = Error;
|
|
|
|
|
|
|
|
|
|
fn try_from(bind: Bind) -> Result<BytesMut, Error> {
|
|
|
|
|
let mut bytes = BytesMut::new();
|
|
|
|
|
|
|
|
|
|
let portal_binding = CString::new(bind.portal)?;
|
|
|
|
|
let portal = portal_binding.as_bytes_with_nul();
|
|
|
|
|
|
|
|
|
|
let prepared_statement_binding = CString::new(bind.prepared_statement)?;
|
|
|
|
|
let prepared_statement = prepared_statement_binding.as_bytes_with_nul();
|
|
|
|
|
|
|
|
|
|
let mut len = 4 // self
|
|
|
|
|
+ portal.len()
|
|
|
|
|
+ prepared_statement.len()
|
|
|
|
|
+ 2 // num_param_format_codes
|
|
|
|
|
+ 2 * bind.num_param_format_codes as usize // num_param_format_codes
|
|
|
|
|
+ 2; // num_param_values
|
|
|
|
|
|
|
|
|
|
for (param_len, _) in &bind.param_values {
|
|
|
|
|
len += 4 + *param_len as usize;
|
|
|
|
|
}
|
|
|
|
|
len += 2; // num_result_column_format_codes
|
|
|
|
|
len += 2 * bind.num_result_column_format_codes as usize;
|
|
|
|
|
|
|
|
|
|
bytes.put_u8(bind.code as u8);
|
|
|
|
|
bytes.put_i32(len as i32);
|
|
|
|
|
bytes.put_slice(portal);
|
|
|
|
|
bytes.put_slice(prepared_statement);
|
|
|
|
|
bytes.put_i16(bind.num_param_format_codes);
|
|
|
|
|
for param_format_code in bind.param_format_codes {
|
|
|
|
|
bytes.put_i16(param_format_code);
|
|
|
|
|
}
|
|
|
|
|
bytes.put_i16(bind.num_param_values);
|
|
|
|
|
for (param_len, param) in bind.param_values {
|
|
|
|
|
bytes.put_i32(param_len);
|
|
|
|
|
bytes.put_slice(¶m);
|
|
|
|
|
}
|
|
|
|
|
bytes.put_i16(bind.num_result_column_format_codes);
|
|
|
|
|
for result_column_format_code in bind.result_columns_format_codes {
|
|
|
|
|
bytes.put_i16(result_column_format_code);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Ok(bytes)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl Bind {
|
2023-10-25 18:11:57 -04:00
|
|
|
/// Gets the name of the prepared statement from the buffer
|
|
|
|
|
pub fn get_name(buf: &BytesMut) -> Result<String, Error> {
|
|
|
|
|
let mut cursor = Cursor::new(buf);
|
|
|
|
|
// Skip the code and length
|
|
|
|
|
cursor.advance(mem::size_of::<u8>() + mem::size_of::<i32>());
|
|
|
|
|
cursor.read_string()?;
|
|
|
|
|
cursor.read_string()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Renames the prepared statement to a new name
|
|
|
|
|
pub fn rename(buf: BytesMut, new_name: &str) -> Result<BytesMut, Error> {
|
|
|
|
|
let mut cursor = Cursor::new(&buf);
|
|
|
|
|
// Read basic data from the cursor
|
|
|
|
|
let code = cursor.get_u8();
|
|
|
|
|
let current_len = cursor.get_i32();
|
|
|
|
|
let portal = cursor.read_string()?;
|
|
|
|
|
let prepared_statement = cursor.read_string()?;
|
|
|
|
|
|
|
|
|
|
// Calculate new length
|
|
|
|
|
let new_len = current_len + new_name.len() as i32 - prepared_statement.len() as i32;
|
|
|
|
|
|
|
|
|
|
// Begin building the response buffer
|
|
|
|
|
let mut response_buf = BytesMut::with_capacity(new_len as usize + 1);
|
|
|
|
|
response_buf.put_u8(code);
|
|
|
|
|
response_buf.put_i32(new_len);
|
|
|
|
|
|
|
|
|
|
// Put the portal and new name into the buffer
|
|
|
|
|
// Note: panic if the provided string contains null byte
|
|
|
|
|
response_buf.put_slice(CString::new(portal)?.as_bytes_with_nul());
|
|
|
|
|
response_buf.put_slice(CString::new(new_name)?.as_bytes_with_nul());
|
|
|
|
|
|
|
|
|
|
// Add the remainder of the original buffer into the response
|
|
|
|
|
response_buf.put_slice(&buf[cursor.position() as usize..]);
|
|
|
|
|
|
|
|
|
|
// Return the buffer
|
|
|
|
|
Ok(response_buf)
|
2023-06-16 12:57:44 -07:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub fn anonymous(&self) -> bool {
|
|
|
|
|
self.prepared_statement.is_empty()
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Describe message.
///
/// Fields are listed in the order they appear in the message.
#[derive(Debug, Clone)]
pub struct Describe {
    /// Message code byte.
    code: char,

    /// Length as read from the message; it is recomputed when serializing.
    #[allow(dead_code)]
    len: i32,
    /// Byte that follows the length in the message.
    pub target: char,
    /// Name carried by the message; may be empty.
    pub statement_name: String,
}
|
|
|
|
|
|
|
|
|
|
impl TryFrom<&BytesMut> for Describe {
|
|
|
|
|
type Error = Error;
|
|
|
|
|
|
|
|
|
|
fn try_from(bytes: &BytesMut) -> Result<Describe, Error> {
|
|
|
|
|
let mut cursor = Cursor::new(bytes);
|
|
|
|
|
let code = cursor.get_u8() as char;
|
|
|
|
|
let len = cursor.get_i32();
|
|
|
|
|
let target = cursor.get_u8() as char;
|
|
|
|
|
let statement_name = cursor.read_string()?;
|
|
|
|
|
|
|
|
|
|
Ok(Describe {
|
|
|
|
|
code,
|
|
|
|
|
len,
|
|
|
|
|
target,
|
|
|
|
|
statement_name,
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl TryFrom<Describe> for BytesMut {
|
|
|
|
|
type Error = Error;
|
|
|
|
|
|
|
|
|
|
fn try_from(describe: Describe) -> Result<BytesMut, Error> {
|
|
|
|
|
let mut bytes = BytesMut::new();
|
|
|
|
|
let statement_name_binding = CString::new(describe.statement_name)?;
|
|
|
|
|
let statement_name = statement_name_binding.as_bytes_with_nul();
|
|
|
|
|
let len = 4 + 1 + statement_name.len();
|
|
|
|
|
|
|
|
|
|
bytes.put_u8(describe.code as u8);
|
|
|
|
|
bytes.put_i32(len as i32);
|
|
|
|
|
bytes.put_u8(describe.target as u8);
|
|
|
|
|
bytes.put_slice(statement_name);
|
|
|
|
|
|
|
|
|
|
Ok(bytes)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl Describe {
|
2023-10-25 18:11:57 -04:00
|
|
|
pub fn empty_new() -> Describe {
|
|
|
|
|
Describe {
|
|
|
|
|
code: 'D',
|
|
|
|
|
len: 4 + 1 + 1,
|
|
|
|
|
target: 'S',
|
|
|
|
|
statement_name: "".to_string(),
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2023-06-16 12:57:44 -07:00
|
|
|
pub fn rename(mut self, name: &str) -> Self {
|
|
|
|
|
self.statement_name = name.to_string();
|
|
|
|
|
self
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub fn anonymous(&self) -> bool {
|
|
|
|
|
self.statement_name.is_empty()
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2023-06-18 23:02:34 -07:00
|
|
|
/// Close (F) message.
/// See: <https://www.postgresql.org/docs/current/protocol-message-formats.html>
///
/// Fields are listed in the order they appear in the message.
#[derive(Clone, Debug)]
pub struct Close {
    /// Message code byte.
    code: char,
    /// Length as read or estimated; it is recomputed when serializing.
    #[allow(dead_code)]
    len: i32,
    /// Byte that follows the length; `'S'` marks a prepared statement.
    close_type: char,
    /// Name carried by the message; may be empty.
    pub name: String,
}
|
|
|
|
|
|
|
|
|
|
impl TryFrom<&BytesMut> for Close {
|
|
|
|
|
type Error = Error;
|
|
|
|
|
|
|
|
|
|
fn try_from(bytes: &BytesMut) -> Result<Close, Error> {
|
|
|
|
|
let mut cursor = Cursor::new(bytes);
|
|
|
|
|
let code = cursor.get_u8() as char;
|
|
|
|
|
let len = cursor.get_i32();
|
|
|
|
|
let close_type = cursor.get_u8() as char;
|
|
|
|
|
let name = cursor.read_string()?;
|
|
|
|
|
|
|
|
|
|
Ok(Close {
|
|
|
|
|
code,
|
|
|
|
|
len,
|
|
|
|
|
close_type,
|
|
|
|
|
name,
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl TryFrom<Close> for BytesMut {
|
|
|
|
|
type Error = Error;
|
|
|
|
|
|
|
|
|
|
fn try_from(close: Close) -> Result<BytesMut, Error> {
|
|
|
|
|
debug!("Close: {:?}", close);
|
|
|
|
|
|
|
|
|
|
let mut bytes = BytesMut::new();
|
|
|
|
|
let name_binding = CString::new(close.name)?;
|
|
|
|
|
let name = name_binding.as_bytes_with_nul();
|
|
|
|
|
let len = 4 + 1 + name.len();
|
|
|
|
|
|
|
|
|
|
bytes.put_u8(close.code as u8);
|
|
|
|
|
bytes.put_i32(len as i32);
|
|
|
|
|
bytes.put_u8(close.close_type as u8);
|
|
|
|
|
bytes.put_slice(name);
|
|
|
|
|
|
|
|
|
|
Ok(bytes)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl Close {
|
|
|
|
|
pub fn new(name: &str) -> Close {
|
|
|
|
|
let name = name.to_string();
|
|
|
|
|
|
|
|
|
|
Close {
|
|
|
|
|
code: 'C',
|
|
|
|
|
len: 4 + 1 + name.len() as i32 + 1, // will be recalculated
|
|
|
|
|
close_type: 'S',
|
|
|
|
|
name,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub fn is_prepared_statement(&self) -> bool {
|
|
|
|
|
self.close_type == 'S'
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub fn anonymous(&self) -> bool {
|
|
|
|
|
self.name.is_empty()
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub fn close_complete() -> BytesMut {
|
|
|
|
|
let mut bytes = BytesMut::new();
|
|
|
|
|
bytes.put_u8(b'3');
|
|
|
|
|
bytes.put_i32(4);
|
|
|
|
|
bytes
|
|
|
|
|
}
|
|
|
|
|
|
2023-08-09 13:14:05 -03:00
|
|
|
// from https://www.postgresql.org/docs/12/protocol-error-fields.html
/// Fields of a server error message.
///
/// The trailing comment on each field is the one-byte field code it is parsed from.
#[derive(Debug, Default, PartialEq)]
pub struct PgErrorMsg {
    pub severity_localized: String,     // S
    pub severity: String,               // V
    pub code: String,                   // C
    pub message: String,                // M
    pub detail: Option<String>,         // D
    pub hint: Option<String>,           // H
    pub position: Option<u32>,          // P
    pub internal_position: Option<u32>, // p
    pub internal_query: Option<String>, // q
    pub where_context: Option<String>,  // W
    pub schema_name: Option<String>,    // s
    pub table_name: Option<String>,     // t
    pub column_name: Option<String>,    // c
    pub data_type_name: Option<String>, // d
    pub constraint_name: Option<String>, // n
    pub file_name: Option<String>,      // F
    pub line: Option<u32>,              // L
    pub routine: Option<String>,        // R
}

// TODO: implement with https://docs.rs/derive_more/latest/derive_more/
impl Display for PgErrorMsg {
    /// Writes `[name: value]` groups: severity, code and message always, every
    /// optional field only when it is set, followed by a single space.
    /// `severity_localized` is not included.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        /// Writes `[label: value]` when `value` is set and nothing otherwise.
        fn write_optional<T: Display>(
            f: &mut Formatter<'_>,
            label: &str,
            value: &Option<T>,
        ) -> std::fmt::Result {
            match value {
                Some(val) => write!(f, "[{label}: {val}]"),
                None => Ok(()),
            }
        }

        write!(f, "[severity: {}]", self.severity)?;
        write!(f, "[code: {}]", self.code)?;
        write!(f, "[message: {}]", self.message)?;

        write_optional(f, "detail", &self.detail)?;
        write_optional(f, "hint", &self.hint)?;
        write_optional(f, "position", &self.position)?;
        write_optional(f, "internal_position", &self.internal_position)?;
        // Written exactly once; it used to be emitted twice in a row.
        write_optional(f, "internal_query", &self.internal_query)?;
        write_optional(f, "where", &self.where_context)?;
        write_optional(f, "schema_name", &self.schema_name)?;
        write_optional(f, "table_name", &self.table_name)?;
        write_optional(f, "column_name", &self.column_name)?;
        write_optional(f, "data_type_name", &self.data_type_name)?;
        write_optional(f, "constraint_name", &self.constraint_name)?;
        write_optional(f, "file_name", &self.file_name)?;
        write_optional(f, "line", &self.line)?;
        write_optional(f, "routine", &self.routine)?;

        write!(f, " ")?;

        Ok(())
    }
}
|
|
|
|
|
|
|
|
|
|
impl PgErrorMsg {
|
2023-10-25 18:11:57 -04:00
|
|
|
pub fn parse(error_msg: &[u8]) -> Result<PgErrorMsg, Error> {
|
2023-08-09 13:14:05 -03:00
|
|
|
let mut out = PgErrorMsg {
|
|
|
|
|
severity_localized: "".to_string(),
|
|
|
|
|
severity: "".to_string(),
|
|
|
|
|
code: "".to_string(),
|
|
|
|
|
message: "".to_string(),
|
|
|
|
|
detail: None,
|
|
|
|
|
hint: None,
|
|
|
|
|
position: None,
|
|
|
|
|
internal_position: None,
|
|
|
|
|
internal_query: None,
|
|
|
|
|
where_context: None,
|
|
|
|
|
schema_name: None,
|
|
|
|
|
table_name: None,
|
|
|
|
|
column_name: None,
|
|
|
|
|
data_type_name: None,
|
|
|
|
|
constraint_name: None,
|
|
|
|
|
file_name: None,
|
|
|
|
|
line: None,
|
|
|
|
|
routine: None,
|
|
|
|
|
};
|
|
|
|
|
for msg_part in error_msg.split(|v| *v == MESSAGE_TERMINATOR) {
|
|
|
|
|
if msg_part.is_empty() {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
let msg_content = match String::from_utf8_lossy(&msg_part[1..]).parse() {
|
|
|
|
|
Ok(c) => c,
|
|
|
|
|
Err(err) => {
|
|
|
|
|
return Err(Error::ServerMessageParserError(format!(
|
|
|
|
|
"could not parse server message field. err {:?}",
|
|
|
|
|
err
|
|
|
|
|
)))
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
match &msg_part[0] {
|
|
|
|
|
b'S' => {
|
|
|
|
|
out.severity_localized = msg_content;
|
|
|
|
|
}
|
|
|
|
|
b'V' => {
|
|
|
|
|
out.severity = msg_content;
|
|
|
|
|
}
|
|
|
|
|
b'C' => {
|
|
|
|
|
out.code = msg_content;
|
|
|
|
|
}
|
|
|
|
|
b'M' => {
|
|
|
|
|
out.message = msg_content;
|
|
|
|
|
}
|
|
|
|
|
b'D' => {
|
|
|
|
|
out.detail = Some(msg_content);
|
|
|
|
|
}
|
|
|
|
|
b'H' => {
|
|
|
|
|
out.hint = Some(msg_content);
|
|
|
|
|
}
|
|
|
|
|
b'P' => out.position = Some(u32::from_str(msg_content.as_str()).unwrap_or(0)),
|
|
|
|
|
b'p' => {
|
|
|
|
|
out.internal_position = Some(u32::from_str(msg_content.as_str()).unwrap_or(0))
|
|
|
|
|
}
|
|
|
|
|
b'q' => {
|
|
|
|
|
out.internal_query = Some(msg_content);
|
|
|
|
|
}
|
|
|
|
|
b'W' => {
|
|
|
|
|
out.where_context = Some(msg_content);
|
|
|
|
|
}
|
|
|
|
|
b's' => {
|
|
|
|
|
out.schema_name = Some(msg_content);
|
|
|
|
|
}
|
|
|
|
|
b't' => {
|
|
|
|
|
out.table_name = Some(msg_content);
|
|
|
|
|
}
|
|
|
|
|
b'c' => {
|
|
|
|
|
out.column_name = Some(msg_content);
|
|
|
|
|
}
|
|
|
|
|
b'd' => {
|
|
|
|
|
out.data_type_name = Some(msg_content);
|
|
|
|
|
}
|
|
|
|
|
b'n' => {
|
|
|
|
|
out.constraint_name = Some(msg_content);
|
|
|
|
|
}
|
|
|
|
|
b'F' => {
|
|
|
|
|
out.file_name = Some(msg_content);
|
|
|
|
|
}
|
|
|
|
|
b'L' => out.line = Some(u32::from_str(msg_content.as_str()).unwrap_or(0)),
|
|
|
|
|
b'R' => {
|
|
|
|
|
out.routine = Some(msg_content);
|
|
|
|
|
}
|
|
|
|
|
_ => {}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Ok(out)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[cfg(test)]
mod tests {
    use crate::messages::PgErrorMsg;
    use log::{error, info};

    /// Encodes one error field: the kind byte, the content, then a NUL terminator.
    fn field(kind: char, content: &str) -> Vec<u8> {
        let mut encoded = String::with_capacity(kind.len_utf8() + content.len() + 1);
        encoded.push(kind);
        encoded.push_str(content);
        encoded.push('\0');
        encoded.into_bytes()
    }

    /// Concatenates the encoded form of every `(kind, content)` pair, in order.
    fn message(fields: &[(char, &str)]) -> Vec<u8> {
        fields
            .iter()
            .flat_map(|&(kind, content)| field(kind, content))
            .collect()
    }

    #[test]
    fn parse_fields() {
        let severity = "FATAL";
        let error_code = "29P02";
        let message_text = "password authentication failed for user \"wrong_user\"";
        let detail_msg = "super detailed message";
        let hint_msg = "hint detail here";
        let internal_query = "SELECT * from foo;";
        let where_msg = "where goes here";
        let schema_msg = "schema_name";
        let table_msg = "table_name";
        let column_msg = "column_name";
        let data_type_msg = "type_name";
        let constraint_msg = "constraint_name";
        let file_msg = "pgcat.c";
        let routine_msg = "my_failing_routine";

        // Every field the parser knows about, in the order the original test sent them.
        let complete_msg = message(&[
            ('S', severity),
            ('V', severity),
            ('C', error_code),
            ('M', message_text),
            ('D', detail_msg),
            ('H', hint_msg),
            ('P', "123"),
            ('p', "234"),
            ('q', internal_query),
            ('W', where_msg),
            ('s', schema_msg),
            ('t', table_msg),
            ('c', column_msg),
            ('d', data_type_msg),
            ('n', constraint_msg),
            ('F', file_msg),
            ('L', "335"),
            ('R', routine_msg),
        ]);

        tracing_subscriber::fmt()
            .with_max_level(tracing::Level::INFO)
            .with_ansi(true)
            .init();

        info!(
            "full message: {}",
            PgErrorMsg::parse(&complete_msg).unwrap()
        );

        let expected_complete = PgErrorMsg {
            severity_localized: severity.to_string(),
            severity: severity.to_string(),
            code: error_code.to_string(),
            message: message_text.to_string(),
            detail: Some(detail_msg.to_string()),
            hint: Some(hint_msg.to_string()),
            position: Some(123),
            internal_position: Some(234),
            internal_query: Some(internal_query.to_string()),
            where_context: Some(where_msg.to_string()),
            schema_name: Some(schema_msg.to_string()),
            table_name: Some(table_msg.to_string()),
            column_name: Some(column_msg.to_string()),
            data_type_name: Some(data_type_msg.to_string()),
            constraint_name: Some(constraint_msg.to_string()),
            file_name: Some(file_msg.to_string()),
            line: Some(335),
            routine: Some(routine_msg.to_string()),
        };
        assert_eq!(
            expected_complete,
            PgErrorMsg::parse(&complete_msg).unwrap()
        );

        // Severity, code and message plus the detail field; everything else unset.
        let only_mandatory_msg = message(&[
            ('S', severity),
            ('V', severity),
            ('C', error_code),
            ('M', message_text),
            ('D', detail_msg),
        ]);

        let err_fields = PgErrorMsg::parse(&only_mandatory_msg).unwrap();
        info!("only mandatory fields: {}", &err_fields);
        error!(
            "server error: {}: {}",
            err_fields.severity, err_fields.message
        );

        let expected_mandatory = PgErrorMsg {
            severity_localized: severity.to_string(),
            severity: severity.to_string(),
            code: error_code.to_string(),
            message: message_text.to_string(),
            detail: Some(detail_msg.to_string()),
            ..Default::default()
        };
        assert_eq!(
            expected_mandatory,
            PgErrorMsg::parse(&only_mandatory_msg).unwrap()
        );
    }
}
|