Constants, comments, CI fixes, dead code clean-up (#21)

* constants

* server.rs docs

* client.rs comments

* dead code; comments

* comment

* query cancellation comments

* remove unnecessary cast

* move db setup up one step

* query cancellation test

* new line; good night
This commit is contained in:
Lev Kokotov
2022-02-15 22:45:45 -08:00
committed by GitHub
parent bb84dcee64
commit 7b0ceefb96
8 changed files with 208 additions and 115 deletions

View File

@@ -3,12 +3,12 @@
set -e set -e
set -o xtrace set -o xtrace
psql -e -h 127.0.0.1 -p 5432 -U postgres -f tests/sharding/query_routing_setup.sql
./target/debug/pgcat & ./target/debug/pgcat &
sleep 1 sleep 1
psql -e -h 127.0.0.1 -p 5432 -U postgres -f tests/sharding/query_routing_setup.sql
# Setup PgBench # Setup PgBench
pgbench -i -h 127.0.0.1 -p 6432 pgbench -i -h 127.0.0.1 -p 6432
@@ -18,6 +18,13 @@ pgbench -h 127.0.0.1 -p 6432 -t 500 -c 2 --protocol simple
# Extended protocol # Extended protocol
pgbench -h 127.0.0.1 -p 6432 -t 500 -c 2 --protocol extended pgbench -h 127.0.0.1 -p 6432 -t 500 -c 2 --protocol extended
# COPY TO STDOUT test
psql -h 127.0.0.1 -p 6432 -c 'COPY (SELECT * FROM pgbench_accounts LIMIT 15) TO STDOUT;' > /dev/null
# Query cancellation test
(psql -h 127.0.0.1 -p 6432 -c 'SELECT pg_sleep(5)' || true) &
killall psql -s SIGINT
# Sharding insert # Sharding insert
psql -e -h 127.0.0.1 -p 6432 -f tests/sharding/query_routing_test_insert.sql psql -e -h 127.0.0.1 -p 6432 -f tests/sharding/query_routing_test_insert.sql
@@ -29,3 +36,6 @@ psql -e -h 127.0.0.1 -p 6432 -f tests/sharding/query_routing_test_primary_replic
# Attempt clean shut down # Attempt clean shut down
killall pgcat -s SIGINT killall pgcat -s SIGINT
# Allow for graceful shutdown
sleep 1

View File

@@ -5,12 +5,15 @@ use bytes::{Buf, BufMut, BytesMut};
use once_cell::sync::OnceCell; use once_cell::sync::OnceCell;
use regex::Regex; use regex::Regex;
use tokio::io::{AsyncReadExt, BufReader}; use tokio::io::{AsyncReadExt, BufReader};
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; use tokio::net::{
use tokio::net::TcpStream; tcp::{OwnedReadHalf, OwnedWriteHalf},
TcpStream,
};
use std::collections::HashMap; use std::collections::HashMap;
use crate::config::Role; use crate::config::Role;
use crate::constants::*;
use crate::errors::Error; use crate::errors::Error;
use crate::messages::*; use crate::messages::*;
use crate::pool::{ClientServerMap, ConnectionPool}; use crate::pool::{ClientServerMap, ConnectionPool};
@@ -97,7 +100,7 @@ impl Client {
match code { match code {
// Client wants SSL. We don't support it at the moment. // Client wants SSL. We don't support it at the moment.
80877103 => { SSL_REQUEST_CODE => {
let mut no = BytesMut::with_capacity(1); let mut no = BytesMut::with_capacity(1);
no.put_u8(b'N'); no.put_u8(b'N');
@@ -105,7 +108,7 @@ impl Client {
} }
// Regular startup message. // Regular startup message.
196608 => { PROTOCOL_VERSION_NUMBER => {
// TODO: perform actual auth. // TODO: perform actual auth.
let parameters = parse_startup(bytes.clone())?; let parameters = parse_startup(bytes.clone())?;
@@ -138,7 +141,7 @@ impl Client {
} }
// Query cancel request. // Query cancel request.
80877102 => { CANCEL_REQUEST_CODE => {
let (read, write) = stream.into_split(); let (read, write) = stream.into_split();
let process_id = bytes.get_i32(); let process_id = bytes.get_i32();
@@ -168,23 +171,31 @@ impl Client {
/// Client loop. We handle all messages between the client and the database here. /// Client loop. We handle all messages between the client and the database here.
pub async fn handle(&mut self, mut pool: ConnectionPool) -> Result<(), Error> { pub async fn handle(&mut self, mut pool: ConnectionPool) -> Result<(), Error> {
// Special: cancelling existing running query // The client wants to cancel a query it has issued previously.
if self.cancel_mode { if self.cancel_mode {
let (process_id, secret_key, address, port) = { let (process_id, secret_key, address, port) = {
let guard = self.client_server_map.lock().unwrap(); let guard = self.client_server_map.lock().unwrap();
match guard.get(&(self.process_id, self.secret_key)) { match guard.get(&(self.process_id, self.secret_key)) {
// Drop the mutex as soon as possible. // Drop the mutex as soon as possible.
// We found the server the client is using for its query
// that it wants to cancel.
Some((process_id, secret_key, address, port)) => ( Some((process_id, secret_key, address, port)) => (
process_id.clone(), process_id.clone(),
secret_key.clone(), secret_key.clone(),
address.clone(), address.clone(),
port.clone(), port.clone(),
), ),
// The client doesn't know / got the wrong server,
// we're closing the connection for security reasons.
None => return Ok(()), None => return Ok(()),
} }
}; };
// TODO: pass actual server host and port somewhere. // Opens a new separate connection to the server, sends the backend_id
// and secret_key and then closes it for security reasons. No other interactions
// take place.
return Ok(Server::cancel(&address, &port, process_id, secret_key).await?); return Ok(Server::cancel(&address, &port, process_id, secret_key).await?);
} }
@@ -217,7 +228,7 @@ impl Client {
}; };
// Parse for special server role selection command. // Parse for special server role selection command.
// // SET SERVER ROLE TO '(primary|replica)';
match self.select_role(message.clone()) { match self.select_role(message.clone()) {
Some(r) => { Some(r) => {
custom_protocol_response_ok(&mut self.write, "SET SERVER ROLE").await?; custom_protocol_response_ok(&mut self.write, "SET SERVER ROLE").await?;
@@ -236,15 +247,17 @@ impl Client {
} }
}; };
let mut proxy = connection.0; let mut reference = connection.0;
let _address = connection.1; let _address = connection.1;
let server = &mut *proxy; let server = &mut *reference;
// Claim this server as mine for query cancellation. // Claim this server as mine for query cancellation.
server.claim(self.process_id, self.secret_key); server.claim(self.process_id, self.secret_key);
// Transaction loop. Multiple queries can be issued by the client here.
// The connection belongs to the client until the transaction is over,
// or until the client disconnects if we are in session mode.
loop { loop {
// No messages in the buffer, read one.
let mut message = if message.len() == 0 { let mut message = if message.len() == 0 {
match read_message(&mut self.read).await { match read_message(&mut self.read).await {
Ok(message) => message, Ok(message) => message,
@@ -268,19 +281,26 @@ impl Client {
msg msg
}; };
let original = message.clone(); // To be forwarded to the server // The message will be forwarded to the server intact. We still would like to
// parse it below to figure out what to do with it.
let original = message.clone();
let code = message.get_u8() as char; let code = message.get_u8() as char;
let _len = message.get_i32() as usize; let _len = message.get_i32() as usize;
match code { match code {
// ReadyForQuery
'Q' => { 'Q' => {
// TODO: implement retries here for read-only transactions. // TODO: implement retries here for read-only transactions.
server.send(original).await?; server.send(original).await?;
// Read all data the server has to offer, which can be multiple messages
// buffered in 8196 bytes chunks.
loop { loop {
// TODO: implement retries here for read-only transactions. // TODO: implement retries here for read-only transactions.
let response = server.recv().await?; let response = server.recv().await?;
// Send server reply to the client.
match write_all_half(&mut self.write, response).await { match write_all_half(&mut self.write, response).await {
Ok(_) => (), Ok(_) => (),
Err(err) => { Err(err) => {
@@ -294,15 +314,18 @@ impl Client {
} }
} }
// Send statistic // Report query executed statistics.
self.stats.query(); self.stats.query();
// Transaction over // The transaction is over, we can release the connection back to the pool.
if !server.in_transaction() { if !server.in_transaction() {
// Report transaction executed statistics.
self.stats.transaction(); self.stats.transaction();
// Release server // Release server back to the pool if we are in transaction mode.
// If we are in session mode, we keep the server until the client disconnects.
if self.transaction_mode { if self.transaction_mode {
// Report this client as idle.
self.stats.client_idle(); self.stats.client_idle();
shard = None; shard = None;
@@ -313,6 +336,7 @@ impl Client {
} }
} }
// Terminate
'X' => { 'X' => {
// Client closing. Rollback and clean up // Client closing. Rollback and clean up
// connection before releasing into the pool. // connection before releasing into the pool.
@@ -326,35 +350,46 @@ impl Client {
return Ok(()); return Ok(());
} }
// Parse
// The query with placeholders is here, e.g. `SELECT * FROM users WHERE email = $1 AND active = $2`.
'P' => { 'P' => {
// Extended protocol, let's buffer most of it
self.buffer.put(&original[..]); self.buffer.put(&original[..]);
} }
// Bind
// The placeholder's replacements are here, e.g. 'user@email.com' and 'true'
'B' => { 'B' => {
self.buffer.put(&original[..]); self.buffer.put(&original[..]);
} }
// Describe // Describe
// Command a client can issue to describe a previously prepared named statement.
'D' => { 'D' => {
self.buffer.put(&original[..]); self.buffer.put(&original[..]);
} }
// Execute
// Execute a prepared statement prepared in `P` and bound in `B`.
'E' => { 'E' => {
self.buffer.put(&original[..]); self.buffer.put(&original[..]);
} }
// Sync
// Frontend (client) is asking for the query result now.
'S' => { 'S' => {
// Extended protocol, client requests sync
self.buffer.put(&original[..]); self.buffer.put(&original[..]);
// TODO: retries for read-only transactions // TODO: retries for read-only transactions.
server.send(self.buffer.clone()).await?; server.send(self.buffer.clone()).await?;
self.buffer.clear(); self.buffer.clear();
// Read all data the server has to offer, which can be multiple messages
// buffered in 8196 bytes chunks.
loop { loop {
// TODO: retries for read-only transactions // TODO: retries for read-only transactions
let response = server.recv().await?; let response = server.recv().await?;
match write_all_half(&mut self.write, response).await { match write_all_half(&mut self.write, response).await {
Ok(_) => (), Ok(_) => (),
Err(err) => { Err(err) => {
@@ -368,9 +403,11 @@ impl Client {
} }
} }
// Report query executed statistics.
self.stats.query(); self.stats.query();
// Release server // Release server back to the pool if we are in transaction mode.
// If we are in session mode, we keep the server until the client disconnects.
if !server.in_transaction() { if !server.in_transaction() {
self.stats.transaction(); self.stats.transaction();
@@ -392,10 +429,13 @@ impl Client {
server.send(original).await?; server.send(original).await?;
} }
// CopyDone or CopyFail
// Copy is done, successfully or not.
'c' | 'f' => { 'c' | 'f' => {
// Copy is done.
server.send(original).await?; server.send(original).await?;
let response = server.recv().await?; let response = server.recv().await?;
match write_all_half(&mut self.write, response).await { match write_all_half(&mut self.write, response).await {
Ok(_) => (), Ok(_) => (),
Err(err) => { Err(err) => {
@@ -404,24 +444,29 @@ impl Client {
} }
}; };
// Release the server // Release server back to the pool if we are in transaction mode.
// If we are in session mode, we keep the server until the client disconnects.
if !server.in_transaction() { if !server.in_transaction() {
self.stats.transaction(); self.stats.transaction();
if self.transaction_mode { if self.transaction_mode {
shard = None; shard = None;
role = self.default_server_role; role = self.default_server_role;
break; break;
} }
} }
} }
// Some unexpected message. We either did not implement the protocol correctly
// or this is not a Postgres client we're talking to.
_ => { _ => {
println!(">>> Unexpected code: {}", code); println!(">>> Unexpected code: {}", code);
} }
} }
} }
// The server is no longer bound to us, we can't cancel it's queries anymore.
self.release(); self.release();
} }
} }
@@ -450,6 +495,7 @@ impl Client {
let len = buf.get_i32(); let len = buf.get_i32();
let query = String::from_utf8_lossy(&buf[..len as usize - 4 - 1]).to_ascii_uppercase(); // Don't read the ternminating null let query = String::from_utf8_lossy(&buf[..len as usize - 4 - 1]).to_ascii_uppercase(); // Don't read the ternminating null
let rgx = match SHARDING_REGEX_RE.get() { let rgx = match SHARDING_REGEX_RE.get() {
Some(r) => r, Some(r) => r,
None => return None, None => return None,
@@ -457,11 +503,13 @@ impl Client {
if rgx.is_match(&query) { if rgx.is_match(&query) {
let shard = query.split("'").collect::<Vec<&str>>()[1]; let shard = query.split("'").collect::<Vec<&str>>()[1];
match shard.parse::<i64>() { match shard.parse::<i64>() {
Ok(shard) => { Ok(shard) => {
let sharder = Sharder::new(shards); let sharder = Sharder::new(shards);
Some(sharder.pg_bigint_hash(shard)) Some(sharder.pg_bigint_hash(shard))
} }
Err(_) => None, Err(_) => None,
} }
} else { } else {
@@ -481,6 +529,7 @@ impl Client {
let len = buf.get_i32(); let len = buf.get_i32();
let query = String::from_utf8_lossy(&buf[..len as usize - 4 - 1]).to_ascii_uppercase(); let query = String::from_utf8_lossy(&buf[..len as usize - 4 - 1]).to_ascii_uppercase();
let rgx = match ROLE_REGEX_RE.get() { let rgx = match ROLE_REGEX_RE.get() {
Some(r) => r, Some(r) => r,
None => return None, None => return None,
@@ -490,6 +539,7 @@ impl Client {
// it'll be time to abstract :). // it'll be time to abstract :).
if rgx.is_match(&query) { if rgx.is_match(&query) {
let role = query.split("'").collect::<Vec<&str>>()[1]; let role = query.split("'").collect::<Vec<&str>>()[1];
match role { match role {
"PRIMARY" => Some(Role::Primary), "PRIMARY" => Some(Role::Primary),
"REPLICA" => Some(Role::Replica), "REPLICA" => Some(Role::Replica),

22
src/constants.rs Normal file
View File

@@ -0,0 +1,22 @@
/// Various protocol constants, as defined in
/// https://www.postgresql.org/docs/12/protocol-message-formats.html
/// and elsewhere in the source code.
/// Also other constants we use elsewhere.
// Used in the StartupMessage to indicate regular handshake.
pub const PROTOCOL_VERSION_NUMBER: i32 = 196608;
// SSLRequest: used to indicate we want an SSL connection.
pub const SSL_REQUEST_CODE: i32 = 80877103;
// CancelRequest: the cancel request code.
pub const CANCEL_REQUEST_CODE: i32 = 80877102;
// AuthenticationMD5Password
pub const MD5_ENCRYPTED_PASSWORD: i32 = 5;
// AuthenticationOk
pub const AUTHENTICATION_SUCCESSFUL: i32 = 0;
// ErrorResponse: A code identifying the field type; if zero, this is the message terminator and no string follows.
pub const MESSAGE_TERMINATOR: u8 = 0;

View File

@@ -35,6 +35,7 @@ use tokio::sync::mpsc;
mod client; mod client;
mod config; mod config;
mod constants;
mod errors; mod errors;
mod messages; mod messages;
mod pool; mod pool;

View File

@@ -1,18 +1,17 @@
/// Helper functions to send one-off protocol messages
/// and handle TcpStream (TCP socket).
use bytes::{Buf, BufMut, BytesMut}; use bytes::{Buf, BufMut, BytesMut};
use md5::{Digest, Md5}; use md5::{Digest, Md5};
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader}; use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; use tokio::net::{
use tokio::net::TcpStream; tcp::{OwnedReadHalf, OwnedWriteHalf},
TcpStream,
};
use std::collections::HashMap; use std::collections::HashMap;
use crate::errors::Error; use crate::errors::Error;
// This is a funny one. `psql` parses this to figure out which
// queries to send when using shortcuts, e.g. \d+.
// No longer used. Keeping it here until I'm sure we don't need it again.
const _SERVER_VESION: &str = "12.9 (Ubuntu 12.9-0ubuntu0.20.04.1)";
/// Tell the client that authentication handshake completed successfully. /// Tell the client that authentication handshake completed successfully.
pub async fn auth_ok(stream: &mut TcpStream) -> Result<(), Error> { pub async fn auth_ok(stream: &mut TcpStream) -> Result<(), Error> {
let mut auth_ok = BytesMut::with_capacity(9); let mut auth_ok = BytesMut::with_capacity(9);
@@ -24,32 +23,6 @@ pub async fn auth_ok(stream: &mut TcpStream) -> Result<(), Error> {
Ok(write_all(stream, auth_ok).await?) Ok(write_all(stream, auth_ok).await?)
} }
/// Send server parameters to the client. This will tell the client
/// what server version and what's the encoding we're using.
//
// No longer used. Keeping it here until I'm sure we don't need it again.
//
pub async fn _server_parameters(stream: &mut TcpStream) -> Result<(), Error> {
let client_encoding = BytesMut::from(&b"client_encoding\0UTF8\0"[..]);
let server_version =
BytesMut::from(&format!("server_version\0{}\0", _SERVER_VESION).as_bytes()[..]);
// Client encoding
let len = client_encoding.len() as i32 + 4; // TODO: add more parameters here
let mut res = BytesMut::with_capacity(64);
res.put_u8(b'S');
res.put_i32(len);
res.put_slice(&client_encoding[..]);
let len = server_version.len() as i32 + 4;
res.put_u8(b'S');
res.put_i32(len);
res.put_slice(&server_version[..]);
Ok(write_all(stream, res).await?)
}
/// Give the client the process_id and secret we generated /// Give the client the process_id and secret we generated
/// used in query cancellation. /// used in query cancellation.
pub async fn backend_key_data( pub async fn backend_key_data(
@@ -179,6 +152,7 @@ pub async fn md5_password(
password.push(0); password.push(0);
let mut message = BytesMut::with_capacity(password.len() as usize + 5); let mut message = BytesMut::with_capacity(password.len() as usize + 5);
message.put_u8(b'p'); message.put_u8(b'p');
message.put_i32(password.len() as i32 + 4); message.put_i32(password.len() as i32 + 4);
message.put_slice(&password[..]); message.put_slice(&password[..]);

View File

@@ -10,16 +10,11 @@ use crate::server::Server;
use crate::stats::Reporter; use crate::stats::Reporter;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::{ use std::sync::{Arc, Mutex};
// atomic::{AtomicUsize, Ordering},
Arc,
Mutex,
};
use std::time::Instant; use std::time::Instant;
// Banlist: bad servers go in here. // Banlist: bad servers go in here.
pub type BanList = Arc<Mutex<Vec<HashMap<Address, NaiveDateTime>>>>; pub type BanList = Arc<Mutex<Vec<HashMap<Address, NaiveDateTime>>>>;
// pub type Counter = Arc<AtomicUsize>;
pub type ClientServerMap = Arc<Mutex<HashMap<(i32, i32), (i32, i32, String, String)>>>; pub type ClientServerMap = Arc<Mutex<HashMap<(i32, i32), (i32, i32, String, String)>>>;
#[derive(Clone, Debug)] #[derive(Clone, Debug)]

View File

@@ -1,14 +1,14 @@
#![allow(dead_code)]
#![allow(unused_variables)]
///! Implementation of the PostgreSQL server (database) protocol. ///! Implementation of the PostgreSQL server (database) protocol.
///! Here we are pretending to the a Postgres client. ///! Here we are pretending to the a Postgres client.
use bytes::{Buf, BufMut, BytesMut}; use bytes::{Buf, BufMut, BytesMut};
use tokio::io::{AsyncReadExt, BufReader}; use tokio::io::{AsyncReadExt, BufReader};
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; use tokio::net::{
use tokio::net::TcpStream; tcp::{OwnedReadHalf, OwnedWriteHalf},
TcpStream,
};
use crate::config::{Address, User}; use crate::config::{Address, User};
use crate::constants::*;
use crate::errors::Error; use crate::errors::Error;
use crate::messages::*; use crate::messages::*;
use crate::stats::Reporter; use crate::stats::Reporter;
@@ -20,23 +20,23 @@ pub struct Server {
// port, e.g. 5432, and role, e.g. primary or replica. // port, e.g. 5432, and role, e.g. primary or replica.
address: Address, address: Address,
// Buffered read socket // Buffered read socket.
read: BufReader<OwnedReadHalf>, read: BufReader<OwnedReadHalf>,
// Unbuffered write socket (our client code buffers) // Unbuffered write socket (our client code buffers).
write: OwnedWriteHalf, write: OwnedWriteHalf,
// Our server response buffer // Our server response buffer. We buffer data before we give it to the client.
buffer: BytesMut, buffer: BytesMut,
// Server information the server sent us over on startup // Server information the server sent us over on startup.
server_info: BytesMut, server_info: BytesMut,
// Backend id and secret key used for query cancellation. // Backend id and secret key used for query cancellation.
backend_id: i32, backend_id: i32,
secret_key: i32, secret_key: i32,
// Is the server inside a transaction at the moment. // Is the server inside a transaction or idle.
in_transaction: bool, in_transaction: bool,
// Is there more data for the client to read. // Is there more data for the client to read.
@@ -48,16 +48,16 @@ pub struct Server {
// Mapping of clients and servers used for query cancellation. // Mapping of clients and servers used for query cancellation.
client_server_map: ClientServerMap, client_server_map: ClientServerMap,
// Server connected at // Server connected at.
connected_at: chrono::naive::NaiveDateTime, connected_at: chrono::naive::NaiveDateTime,
// Stats // Reports various metrics, e.g. data sent & received.
stats: Reporter, stats: Reporter,
} }
impl Server { impl Server {
/// Pretend to be the Postgres client and connect to the server given host, port and credentials. /// Pretend to be the Postgres client and connect to the server given host, port and credentials.
/// Perform the authentication and return the server in a ready-for-query mode. /// Perform the authentication and return the server in a ready for query state.
pub async fn startup( pub async fn startup(
address: &Address, address: &Address,
user: &User, user: &User,
@@ -74,13 +74,15 @@ impl Server {
} }
}; };
// Send the startup packet. // Send the startup packet telling the server we're a normal Postgres client.
startup(&mut stream, &user.name, database).await?; startup(&mut stream, &user.name, database).await?;
let mut server_info = BytesMut::with_capacity(25); let mut server_info = BytesMut::new();
let mut backend_id: i32 = 0; let mut backend_id: i32 = 0;
let mut secret_key: i32 = 0; let mut secret_key: i32 = 0;
// We'll be handling multiple packets, but they will all be structured the same.
// We'll loop here until this exchange is complete.
loop { loop {
let code = match stream.read_u8().await { let code = match stream.read_u8().await {
Ok(code) => code as char, Ok(code) => code as char,
@@ -93,16 +95,18 @@ impl Server {
}; };
match code { match code {
// Authentication
'R' => { 'R' => {
// Auth can proceed // Determine which kind of authentication is required, if any.
let code = match stream.read_i32().await { let auth_code = match stream.read_i32().await {
Ok(code) => code, Ok(auth_code) => auth_code,
Err(_) => return Err(Error::SocketError), Err(_) => return Err(Error::SocketError),
}; };
match code { match auth_code {
// MD5 MD5_ENCRYPTED_PASSWORD => {
5 => { // The salt is 4 bytes.
// See: https://www.postgresql.org/docs/12/protocol-message-formats.html
let mut salt = vec![0u8; 4]; let mut salt = vec![0u8; 4];
match stream.read_exact(&mut salt).await { match stream.read_exact(&mut salt).await {
@@ -114,16 +118,16 @@ impl Server {
.await?; .await?;
} }
// Authentication handshake complete. AUTHENTICATION_SUCCESSFUL => (),
0 => (),
_ => { _ => {
println!(">> Unsupported authentication mechanism: {}", code); println!(">> Unsupported authentication mechanism: {}", auth_code);
return Err(Error::ServerError); return Err(Error::ServerError);
} }
} }
} }
// ErrorResponse
'E' => { 'E' => {
let error_code = match stream.read_u8().await { let error_code = match stream.read_u8().await {
Ok(error_code) => error_code, Ok(error_code) => error_code,
@@ -131,46 +135,62 @@ impl Server {
}; };
match error_code { match error_code {
0 => (), // Terminator // No error message is present in the message.
MESSAGE_TERMINATOR => (),
// An error message will be present.
_ => { _ => {
// Read the error message without the terminating null character.
let mut error = vec![0u8; len as usize - 4 - 1]; let mut error = vec![0u8; len as usize - 4 - 1];
match stream.read_exact(&mut error).await { match stream.read_exact(&mut error).await {
Ok(_) => (), Ok(_) => (),
Err(_) => return Err(Error::SocketError), Err(_) => return Err(Error::SocketError),
}; };
// TODO: the error message contains multiple fields; we can decode them and
// present a prettier message to the user.
// See: https://www.postgresql.org/docs/12/protocol-error-fields.html
println!(">> Server error: {}", String::from_utf8_lossy(&error)); println!(">> Server error: {}", String::from_utf8_lossy(&error));
} }
}; };
return Err(Error::ServerError); return Err(Error::ServerError);
} }
// ParameterStatus
'S' => { 'S' => {
// Parameter
let mut param = vec![0u8; len as usize - 4]; let mut param = vec![0u8; len as usize - 4];
match stream.read_exact(&mut param).await { match stream.read_exact(&mut param).await {
Ok(_) => (), Ok(_) => (),
Err(_) => return Err(Error::SocketError), Err(_) => return Err(Error::SocketError),
}; };
// Save the parameter so we can pass it to the client later.
// These can be server_encoding, client_encoding, server timezone, Postgres version,
// and many more interesting things we should know about the Postgres server we are talking to.
server_info.put_u8(b'S'); server_info.put_u8(b'S');
server_info.put_i32(len); server_info.put_i32(len);
server_info.put_slice(&param[..]); server_info.put_slice(&param[..]);
} }
// BackendKeyData
'K' => { 'K' => {
// Query cancellation data. // The frontend must save these values if it wishes to be able to issue CancelRequest messages later.
// See: https://www.postgresql.org/docs/12/protocol-message-formats.html
backend_id = match stream.read_i32().await { backend_id = match stream.read_i32().await {
Ok(id) => id, Ok(id) => id,
Err(err) => return Err(Error::SocketError), Err(_) => return Err(Error::SocketError),
}; };
secret_key = match stream.read_i32().await { secret_key = match stream.read_i32().await {
Ok(id) => id, Ok(id) => id,
Err(err) => return Err(Error::SocketError), Err(_) => return Err(Error::SocketError),
}; };
} }
// ReadyForQuery
'Z' => { 'Z' => {
let mut idle = vec![0u8; len as usize - 4]; let mut idle = vec![0u8; len as usize - 4];
@@ -179,7 +199,8 @@ impl Server {
Err(_) => return Err(Error::SocketError), Err(_) => return Err(Error::SocketError),
}; };
// Startup finished // This is the last step in the client-server connection setup,
// and indicates the server is ready for to query it.
let (read, write) = stream.into_split(); let (read, write) = stream.into_split();
return Ok(Server { return Ok(Server {
@@ -199,6 +220,8 @@ impl Server {
}); });
} }
// We have an unexpected message from the server during this exchange.
// Means we implemented the protocol wrong or we're not talking to a Postgres server.
_ => { _ => {
println!(">> Unknown code: {}", code); println!(">> Unknown code: {}", code);
return Err(Error::ProtocolSyncError); return Err(Error::ProtocolSyncError);
@@ -207,7 +230,7 @@ impl Server {
} }
} }
/// Issue a cancellation request to the server. /// Issue a query cancellation request to the server.
/// Uses a separate connection that's not part of the connection pool. /// Uses a separate connection that's not part of the connection pool.
pub async fn cancel( pub async fn cancel(
host: &str, host: &str,
@@ -225,14 +248,14 @@ impl Server {
let mut bytes = BytesMut::with_capacity(16); let mut bytes = BytesMut::with_capacity(16);
bytes.put_i32(16); bytes.put_i32(16);
bytes.put_i32(80877102); bytes.put_i32(CANCEL_REQUEST_CODE);
bytes.put_i32(process_id); bytes.put_i32(process_id);
bytes.put_i32(secret_key); bytes.put_i32(secret_key);
Ok(write_all(&mut stream, bytes).await?) Ok(write_all(&mut stream, bytes).await?)
} }
/// Send data to the server from the client. /// Send messages to the server from the client.
pub async fn send(&mut self, messages: BytesMut) -> Result<(), Error> { pub async fn send(&mut self, messages: BytesMut) -> Result<(), Error> {
self.stats.data_sent(messages.len()); self.stats.data_sent(messages.len());
@@ -246,7 +269,7 @@ impl Server {
} }
} }
/// Receive data from the server in response to a client request sent previously. /// Receive data from the server in response to a client request.
/// This method must be called multiple times while `self.is_data_available()` is true /// This method must be called multiple times while `self.is_data_available()` is true
/// in order to receive all data the server has to offer. /// in order to receive all data the server has to offer.
pub async fn recv(&mut self) -> Result<BytesMut, Error> { pub async fn recv(&mut self) -> Result<BytesMut, Error> {
@@ -260,79 +283,90 @@ impl Server {
} }
}; };
// Buffer the message we'll forward to the client in a bit. // Buffer the message we'll forward to the client later.
self.buffer.put(&message[..]); self.buffer.put(&message[..]);
let code = message.get_u8() as char; let code = message.get_u8() as char;
let _len = message.get_i32(); let _len = message.get_i32();
match code { match code {
// ReadyForQuery
'Z' => { 'Z' => {
// Ready for query, time to forward buffer to client.
let transaction_state = message.get_u8() as char; let transaction_state = message.get_u8() as char;
match transaction_state { match transaction_state {
// In transaction.
'T' => { 'T' => {
self.in_transaction = true; self.in_transaction = true;
} }
// Idle, transaction over.
'I' => { 'I' => {
self.in_transaction = false; self.in_transaction = false;
} }
// Some error occured, the transaction was rolled back.
'E' => { 'E' => {
self.in_transaction = true; self.in_transaction = true;
} }
// Something totally unexpected, this is not a Postgres server we know.
_ => { _ => {
self.bad = true; self.bad = true;
return Err(Error::ProtocolSyncError); return Err(Error::ProtocolSyncError);
} }
}; };
// There is no more data available from the server.
self.data_available = false; self.data_available = false;
break; break;
} }
// DataRow
'D' => { 'D' => {
// More data is available after this message, this is not the end of the reply.
self.data_available = true; self.data_available = true;
// Don't flush yet, the more we buffer, the faster this goes. // Don't flush yet, the more we buffer, the faster this goes...
// Up to a limit of course. // up to a limit of course.
if self.buffer.len() >= 8196 { if self.buffer.len() >= 8196 {
break; break;
} }
} }
// CopyInResponse: copy is starting from client to server // CopyInResponse: copy is starting from client to server.
'G' => break, 'G' => break,
// CopyOutResponse: copy is starting from the server to the client // CopyOutResponse: copy is starting from the server to the client.
'H' => { 'H' => {
self.data_available = true; self.data_available = true;
break; break;
} }
// CopyData // CopyData: we are not buffering this one because there will be many more
// and we don't know how big this packet could be, best not to take a risk.
'd' => break, 'd' => break,
// CopyDone // CopyDone
'c' => { // Buffer until ReadyForQuery shows up, so don't exit the loop yet.
self.data_available = false; 'c' => (),
// Buffer until ReadyForQuery shows up
}
_ => { // Anything else, e.g. errors, notices, etc.
// Keep buffering, // Keep buffering until ReadyForQuery shows up.
} _ => (),
}; };
} }
let bytes = self.buffer.clone(); let bytes = self.buffer.clone();
// Keep track of how much data we got from the server for stats.
self.stats.data_received(bytes.len()); self.stats.data_received(bytes.len());
// Clear the buffer for next query.
self.buffer.clear(); self.buffer.clear();
// Pass the data back to the client.
Ok(bytes) Ok(bytes)
} }
@@ -381,11 +415,11 @@ impl Server {
} }
/// Execute an arbitrary query against the server. /// Execute an arbitrary query against the server.
/// It will use the Simple query protocol. /// It will use the simple query protocol.
/// Result will not be returned, so this is useful for things like `SET` or `ROLLBACK`. /// Result will not be returned, so this is useful for things like `SET` or `ROLLBACK`.
pub async fn query(&mut self, query: &str) -> Result<(), Error> { pub async fn query(&mut self, query: &str) -> Result<(), Error> {
let mut query = BytesMut::from(&query.as_bytes()[..]); let mut query = BytesMut::from(&query.as_bytes()[..]);
query.put_u8(0); query.put_u8(0); // C-string terminator (NULL character).
let len = query.len() as i32 + 4; let len = query.len() as i32 + 4;
@@ -396,8 +430,10 @@ impl Server {
msg.put_slice(&query[..]); msg.put_slice(&query[..]);
self.send(msg).await?; self.send(msg).await?;
loop { loop {
let _ = self.recv().await?; let _ = self.recv().await?;
if !self.data_available { if !self.data_available {
break; break;
} }
@@ -407,26 +443,31 @@ impl Server {
} }
/// A shorthand for `SET application_name = $1`. /// A shorthand for `SET application_name = $1`.
#[allow(dead_code)]
pub async fn set_name(&mut self, name: &str) -> Result<(), Error> { pub async fn set_name(&mut self, name: &str) -> Result<(), Error> {
Ok(self Ok(self
.query(&format!("SET application_name = '{}'", name)) .query(&format!("SET application_name = '{}'", name))
.await?) .await?)
} }
/// Get the servers address.
#[allow(dead_code)]
pub fn address(&self) -> Address { pub fn address(&self) -> Address {
self.address.clone() self.address.clone()
} }
} }
impl Drop for Server { impl Drop for Server {
// Try to do a clean shut down. /// Try to do a clean shut down. Best effort because
/// the socket is in non-blocking mode, so it may not be ready
/// for a write.
fn drop(&mut self) { fn drop(&mut self) {
let mut bytes = BytesMut::with_capacity(4); let mut bytes = BytesMut::with_capacity(4);
bytes.put_u8(b'X'); bytes.put_u8(b'X');
bytes.put_i32(4); bytes.put_i32(4);
match self.write.try_write(&bytes) { match self.write.try_write(&bytes) {
Ok(n) => (), Ok(_) => (),
Err(_) => (), Err(_) => (),
}; };

View File

@@ -18,7 +18,7 @@ impl Sharder {
let mut lohalf = key as u32; let mut lohalf = key as u32;
let hihalf = (key >> 32) as u32; let hihalf = (key >> 32) as u32;
lohalf ^= if key >= 0 { hihalf } else { !hihalf }; lohalf ^= if key >= 0 { hihalf } else { !hihalf };
Self::combine(0, Self::pg_u32_hash(lohalf)) as usize % self.shards as usize Self::combine(0, Self::pg_u32_hash(lohalf)) as usize % self.shards
} }
#[inline] #[inline]