This commit is contained in:
Mostafa Abdelraouf
2023-03-26 14:52:55 -05:00
parent d66b377a8e
commit 52a980fa0a

View File

@@ -1,10 +1,15 @@
use crate::errors::Error;
use crate::pool::BanReason;
use bb8::PooledConnection;
/// Handle clients by pretending to be a PostgreSQL server.
use bytes::{Buf, BufMut, BytesMut};
use hyper::server::conn;
use log::{debug, error, info, trace, warn};
use crate::pool::ServerPool;
use crate::config::Role;
use std::collections::HashMap;
use std::ops::Add;
use std::time::Instant;
use tokio::io::{split, AsyncReadExt, BufReader, ReadHalf, WriteHalf};
use tokio::net::TcpStream;
@@ -31,6 +36,19 @@ enum ClientConnectionType {
CancelQuery,
}
struct RetryBuffer {
buffer: BytesMut,
retry_count: u32,
}
pub enum ClientFlowControl {
Retry,
PerformNextCommand,
ReleaseConnection,
Disconnect
}
/// The client state. One of these is created per client.
pub struct Client<S, T> {
/// The reads are buffered (8K by default).
@@ -92,6 +110,8 @@ pub struct Client<S, T> {
/// Used to notify clients about an impending shutdown
shutdown: Receiver<()>,
retry_buffer: Option<RetryBuffer>,
}
/// Client entrypoint.
@@ -558,6 +578,7 @@ where
application_name: application_name.to_string(),
shutdown,
connected_to_server: false,
retry_buffer: None,
})
}
@@ -592,155 +613,184 @@ where
application_name: String::from("undefined"),
shutdown,
connected_to_server: false,
retry_buffer: None,
})
}
/// Handle a connected and authenticated client.
pub async fn handle(&mut self) -> Result<(), Error> {
// The client wants to cancel a query it has issued previously.
if self.cancel_mode {
trace!("Sending CancelRequest");
let (process_id, secret_key, address, port) = {
let guard = self.client_server_map.lock();
async fn cancel_query(&mut self) -> Result<(), Error> {
trace!("Sending CancelRequest");
match guard.get(&(self.process_id, self.secret_key)) {
// Drop the mutex as soon as possible.
// We found the server the client is using for its query
// that it wants to cancel.
Some((process_id, secret_key, address, port)) => {
(*process_id, *secret_key, address.clone(), *port)
}
let (process_id, secret_key, address, port) = {
let guard = self.client_server_map.lock();
// The client doesn't know / got the wrong server,
// we're closing the connection for security reasons.
None => return Ok(()),
match guard.get(&(self.process_id, self.secret_key)) {
// Drop the mutex as soon as possible.
// We found the server the client is using for its query
// that it wants to cancel.
Some((process_id, secret_key, address, port)) => {
(*process_id, *secret_key, address.clone(), *port)
}
};
// Opens a new separate connection to the server, sends the backend_id
// and secret_key and then closes it for security reasons. No other interactions
// take place.
return Server::cancel(&address, port, process_id, secret_key).await;
}
// The client doesn't know / got the wrong server,
// we're closing the connection for security reasons.
None => return Ok(()),
}
};
// The query router determines where the query is going to go,
// e.g. primary, replica, which shard.
let mut query_router = QueryRouter::new();
self.stats.client_register(
self.process_id,
self.pool_name.clone(),
self.username.clone(),
self.application_name.clone(),
// Opens a new separate connection to the server, sends the backend_id
// and secret_key and then closes it for security reasons. No other interactions
// take place.
return Server::cancel(&address, port, process_id, secret_key).await;
}
async fn checkout_connection(&mut self, pool: &ConnectionPool, query_router: &mut QueryRouter) -> Result<(PooledConnection<ServerPool>, Address), Error> {
// Grab a server from the pool.
let mut connection = match pool
.get(query_router.shard(), query_router.role(), self.process_id)
.await
{
Ok(conn) => {
debug!("Got connection from pool");
conn
}
Err(err) => {
self.buffer.clear();
error_response(&mut self.write, "could not get connection from the pool")
.await?;
error!("Could not get connection from pool: {{ pool_name: {:?}, username: {:?}, shard: {:?}, role: \"{:?}\", error: \"{:?}\" }}",
self.pool_name.clone(), self.username.clone(), query_router.shard(), query_router.role(), err);
return Err(err);
}
};
let server = &mut connection.0;
let address = connection.1.clone();
// Server is assigned to the client in case the client wants to
// cancel a query later.
server.claim(self.process_id, self.secret_key);
self.connected_to_server = true;
// Update statistics
self.stats
.client_active(self.process_id, server.server_id());
self.last_address_id = Some(address.id);
self.last_server_id = Some(server.server_id());
debug!(
"Client {:?} talking to server {:?}",
self.addr,
server.address()
);
// Our custom protocol loop.
// We expect the client to either start a transaction with regular queries
// or issue commands for our sharding and server selection protocol.
loop {
trace!(
"Client idle, waiting for message, transaction mode: {}",
self.transaction_mode
);
return Ok(connection);
}
// Read a complete message from the client, which normally would be
// either a `Q` (query) or `P` (prepare, extended protocol).
// We can parse it here before grabbing a server from the pool,
// in case the client is sending some custom protocol messages, e.g.
// SET SHARDING KEY TO 'bigint';
async fn client_proc(&mut self, query_router: &mut QueryRouter) -> Result<ClientFlowControl, Error> {
let message = tokio::select! {
_ = self.shutdown.recv() => {
if !self.admin {
error_response_terminal(
&mut self.write,
"terminating connection due to administrator command"
).await?;
return Ok(())
}
trace!(
"Client idle, waiting for message, transaction mode: {}",
self.transaction_mode
);
// Read a complete message from the client, which normally would be
// either a `Q` (query) or `P` (prepare, extended protocol).
// We can parse it here before grabbing a server from the pool,
// in case the client is sending some custom protocol messages, e.g.
// SET SHARDING KEY TO 'bigint';
let message = tokio::select! {
_ = self.shutdown.recv() => {
if !self.admin {
error_response_terminal(
&mut self.write,
"terminating connection due to administrator command"
).await?;
return Ok(ClientFlowControl::Disconnect)
} else {
// Admin clients ignore shutdown.
else {
read_message(&mut self.read).await?
}
},
message_result = read_message(&mut self.read) => message_result?
};
match message[0] as char {
// Buffer extended protocol messages even if we do not have
// a server connection yet. Hopefully, when we get the S message
// we'll be able to allocate a connection. Also, clients do not expect
// the server to respond to these messages so even if we were not able to
// allocate a connection, we wouldn't be able to send back an error message
// to the client so we buffer them and defer the decision to error out or not
// to when we get the S message
'D' | 'E' => {
self.buffer.put(&message[..]);
continue;
read_message(&mut self.read).await?
}
},
'Q' => {
if query_router.query_parser_enabled() {
query_router.infer(&message);
}
}
message_result = read_message(&mut self.read) => message_result?
};
'P' => {
self.buffer.put(&message[..]);
if query_router.query_parser_enabled() {
query_router.infer(&message);
}
continue;
}
'B' => {
self.buffer.put(&message[..]);
if query_router.query_parser_enabled() {
query_router.infer_shard_from_bind(&message);
}
continue;
}
'X' => {
debug!("Client disconnecting");
return Ok(());
}
_ => (),
match message[0] as char {
// Buffer extended protocol messages even if we do not have
// a server connection yet. Hopefully, when we get the S message
// we'll be able to allocate a connection. Also, clients do not expect
// the server to respond to these messages so even if we were not able to
// allocate a connection, we wouldn't be able to send back an error message
// to the client so we buffer them and defer the decision to error out or not
// to when we get the S message
'D' | 'E' => {
self.buffer.put(&message[..]);
return Ok(ClientFlowControl::PerformNextCommand);
}
// Handle admin database queries.
if self.admin {
debug!("Handling admin command");
handle_admin(&mut self.write, message, self.client_server_map.clone()).await?;
continue;
'Q' => {
if query_router.query_parser_enabled() {
query_router.infer(&message);
}
}
// Get a pool instance referenced by the most up-to-date
// pointer. This ensures we always read the latest config
// when starting a query.
let mut pool = self.get_pool().await?;
'P' => {
self.buffer.put(&message[..]);
// Check if the pool is paused and wait until it's resumed.
if pool.wait_paused().await {
// Refresh pool information, something might have changed.
pool = self.get_pool().await?;
if query_router.query_parser_enabled() {
query_router.infer(&message);
}
return Ok(ClientFlowControl::PerformNextCommand);
}
query_router.update_pool_settings(pool.settings.clone());
'B' => {
self.buffer.put(&message[..]);
let current_shard = query_router.shard();
if query_router.query_parser_enabled() {
query_router.infer_shard_from_bind(&message);
}
// Handle all custom protocol commands, if any.
match query_router.try_execute_command(&message) {
// Normal query, not a custom command.
None => (),
return Ok(ClientFlowControl::PerformNextCommand);
}
'X' => {
debug!("Client disconnecting");
return Ok(ClientFlowControl::Disconnect);
}
_ => (),
}
// Handle admin database queries.
if self.admin {
debug!("Handling admin command");
handle_admin(&mut self.write, message, self.client_server_map.clone()).await?;
return Ok(ClientFlowControl::PerformNextCommand);
}
// Get a pool instance referenced by the most up-to-date
// pointer. This ensures we always read the latest config
// when starting a query.
let mut pool = self.get_pool().await?;
// Check if the pool is paused and wait until it's resumed.
if pool.wait_paused().await {
// Refresh pool information, something might have changed.
pool = self.get_pool().await?;
}
query_router.update_pool_settings(pool.settings.clone());
let current_shard = query_router.shard();
// Handle all custom protocol commands, if any.
match query_router.try_execute_command(&message) {
// Normal query, not a custom command.
None => (),
// SET SHARD TO
Some((Command::SetShard, _)) => {
// Selected shard is not configured.
@@ -761,97 +811,48 @@ where
} else {
custom_protocol_response_ok(&mut self.write, "SET SHARD").await?;
}
continue;
return Ok(ClientFlowControl::PerformNextCommand);
}
// SET PRIMARY READS TO
Some((Command::SetPrimaryReads, _)) => {
custom_protocol_response_ok(&mut self.write, "SET PRIMARY READS").await?;
continue;
return Ok(ClientFlowControl::PerformNextCommand);
}
// SET SHARDING KEY TO
Some((Command::SetShardingKey, _)) => {
custom_protocol_response_ok(&mut self.write, "SET SHARDING KEY").await?;
continue;
return Ok(ClientFlowControl::PerformNextCommand);
}
// SET SERVER ROLE TO
Some((Command::SetServerRole, _)) => {
custom_protocol_response_ok(&mut self.write, "SET SERVER ROLE").await?;
continue;
return Ok(ClientFlowControl::PerformNextCommand);
}
// SHOW SERVER ROLE
Some((Command::ShowServerRole, value)) => {
show_response(&mut self.write, "server role", &value).await?;
continue;
return Ok(ClientFlowControl::PerformNextCommand);
}
// SHOW SHARD
Some((Command::ShowShard, value)) => {
show_response(&mut self.write, "shard", &value).await?;
continue;
return Ok(ClientFlowControl::PerformNextCommand);
}
// SHOW PRIMARY READS
Some((Command::ShowPrimaryReads, value)) => {
show_response(&mut self.write, "primary reads", &value).await?;
continue;
return Ok(ClientFlowControl::PerformNextCommand);
}
};
debug!("Waiting for connection from pool");
// Grab a server from the pool.
let connection = match pool
.get(query_router.shard(), query_router.role(), self.process_id)
.await
{
Ok(conn) => {
debug!("Got connection from pool");
conn
}
Err(err) => {
// Client is attempting to get results from the server,
// but we were unable to grab a connection from the pool
// We'll send back an error message and clean the extended
// protocol buffer
if message[0] as char == 'S' {
error!("Got Sync message but failed to get a connection from the pool");
self.buffer.clear();
}
error_response(&mut self.write, "could not get connection from the pool")
.await?;
error!("Could not get connection from pool: {{ pool_name: {:?}, username: {:?}, shard: {:?}, role: \"{:?}\", error: \"{:?}\" }}",
self.pool_name.clone(), self.username.clone(), query_router.shard(), query_router.role(), err);
continue;
}
};
let mut reference = connection.0;
let address = connection.1;
let server = &mut *reference;
// Server is assigned to the client in case the client wants to
// cancel a query later.
server.claim(self.process_id, self.secret_key);
self.connected_to_server = true;
// Update statistics
self.stats
.client_active(self.process_id, server.server_id());
self.last_address_id = Some(address.id);
self.last_server_id = Some(server.server_id());
debug!(
"Client {:?} talking to server {:?}",
self.addr,
server.address()
);
// TODO: investigate other parameters and set them too.
// Set application_name.
@@ -866,178 +867,19 @@ where
// If the client is in session mode, no more custom protocol
// commands will be accepted.
loop {
let message = match initial_message {
None => {
trace!("Waiting for message inside transaction or in session mode");
match read_message(&mut self.read).await {
Ok(message) => message,
Err(err) => {
// Client disconnected inside a transaction.
// Clean up the server and re-use it.
server.checkin_cleanup().await?;
return Err(err);
match self.tx_proc(server, &pool, initial_message).await {
Ok(control_flow) => {
match control_flow {
ClientFlowControl::PerformNextCommand => {
initial_message = None;
continue;
}
control_flow_result => {
return Ok(control_flow_result);
}
}
}
Some(message) => {
initial_message = None;
message
}
};
// The message will be forwarded to the server intact. We still would like to
// parse it below to figure out what to do with it.
// Safe to unwrap because we know this message has a certain length and has the code
// This reads the first byte without advancing the internal pointer and mutating the bytes
let code = *message.get(0).unwrap() as char;
trace!("Message: {}", code);
match code {
// Query
'Q' => {
debug!("Sending query to server");
self.send_and_receive_loop(code, Some(&message), server, &address, &pool)
.await?;
if !server.in_transaction() {
// Report transaction executed statistics.
self.stats.transaction(self.process_id, server.server_id());
// Release server back to the pool if we are in transaction mode.
// If we are in session mode, we keep the server until the client disconnects.
if self.transaction_mode {
break;
}
}
}
// Terminate
'X' => {
server.checkin_cleanup().await?;
self.release();
return Ok(());
}
// Parse
// The query with placeholders is here, e.g. `SELECT * FROM users WHERE email = $1 AND active = $2`.
'P' => {
self.buffer.put(&message[..]);
}
// Bind
// The placeholder's replacements are here, e.g. 'user@email.com' and 'true'
'B' => {
self.buffer.put(&message[..]);
}
// Describe
// Command a client can issue to describe a previously prepared named statement.
'D' => {
self.buffer.put(&message[..]);
}
// Execute
// Execute a prepared statement prepared in `P` and bound in `B`.
'E' => {
self.buffer.put(&message[..]);
}
// Sync
// Frontend (client) is asking for the query result now.
'S' => {
debug!("Sending query to server");
self.buffer.put(&message[..]);
let first_message_code = (*self.buffer.get(0).unwrap_or(&0)) as char;
// Almost certainly true
if first_message_code == 'P' {
// Message layout
// P followed by 32 int followed by null-terminated statement name
// So message code should be in offset 0 of the buffer, first character
// in prepared statement name would be index 5
let first_char_in_name = *self.buffer.get(5).unwrap_or(&0);
if first_char_in_name != 0 {
// This is a named prepared statement
// Server connection state will need to be cleared at checkin
server.mark_dirty();
}
}
self.send_and_receive_loop(code, None, server, &address, &pool)
.await?;
self.buffer.clear();
if !server.in_transaction() {
self.stats.transaction(self.process_id, server.server_id());
// Release server back to the pool if we are in transaction mode.
// If we are in session mode, we keep the server until the client disconnects.
if self.transaction_mode {
break;
}
}
}
// CopyData
'd' => {
self.buffer.put(&message[..]);
// Want to limit buffer size
if self.buffer.len() > 8196 {
// Forward the data to the server,
self.send_server_message(server, &self.buffer, &address, &pool)
.await?;
self.buffer.clear();
}
}
// CopyDone or CopyFail
// Copy is done, successfully or not.
'c' | 'f' => {
// We may already have some copy data in the buffer, add this message to buffer
self.buffer.put(&message[..]);
self.send_server_message(server, &self.buffer, &address, &pool)
.await?;
// Clear the buffer
self.buffer.clear();
let response = self.receive_server_message(server, &address, &pool).await?;
match write_all_half(&mut self.write, &response).await {
Ok(_) => (),
Err(err) => {
server.mark_bad();
return Err(err);
}
};
if !server.in_transaction() {
self.stats.transaction(self.process_id, server.server_id());
// Release server back to the pool if we are in transaction mode.
// If we are in session mode, we keep the server until the client disconnects.
if self.transaction_mode {
break;
}
}
}
// Some unexpected message. We either did not implement the protocol correctly
// or this is not a Postgres client we're talking to.
_ => {
error!("Unexpected code: {}", code);
}
},
Err(err) => return Err(err),
}
}
@@ -1050,6 +892,249 @@ where
self.release();
self.stats.client_idle(self.process_id);
}
#[inline(always)]
pub async fn tx_proc(&mut self, server: &mut Server, pool: &ConnectionPool, initial_message: Option<BytesMut> ) -> Result<ClientFlowControl, Error> {
let message = match initial_message {
None => {
trace!("Waiting for message inside transaction or in session mode");
match read_message(&mut self.read).await {
Ok(message) => message,
Err(err) => {
// Client disconnected inside a transaction.
// Clean up the server and re-use it.
server.checkin_cleanup().await?;
return Err(err);
}
}
}
Some(message) => message
};
// The message will be forwarded to the server intact. We still would like to
// parse it below to figure out what to do with it.
// Safe to unwrap because we know this message has a certain length and has the code
// This reads the first byte without advancing the internal pointer and mutating the bytes
let code = *message.get(0).unwrap() as char;
let address = server.address();
trace!("Message: {}", code);
match code {
// Query
'Q' => {
debug!("Sending query to server");
match self.send_and_receive_loop(code, Some(&message), server, &address, &pool).await {
Ok(_) => self.retry_buffer = None,
Err(_) => {
if server.is_bad() && !server.in_transaction() && server.address().role == Role::Replica {
match self.retry_buffer {
Some(ref mut retry_buffer) => {
if retry_buffer.retry_count < 3 {
retry_buffer.retry_count += 1;
return Ok(ClientFlowControl::Retry);
}
},
None => {
self.retry_buffer = Some(RetryBuffer { buffer: message, retry_count: 0 });
return Ok(ClientFlowControl::Retry);
}
}
}
},
}
if !server.in_transaction() {
// Report transaction executed statistics.
self.stats.transaction(self.process_id, server.server_id());
// Release server back to the pool if we are in transaction mode.
// If we are in session mode, we keep the server until the client disconnects.
if self.transaction_mode {
return Ok(ClientFlowControl::ReleaseConnection);
}
}
}
// Terminate
'X' => {
server.checkin_cleanup().await?;
self.release();
return Ok(ClientFlowControl::Disconnect);
}
// Parse
// The query with placeholders is here, e.g. `SELECT * FROM users WHERE email = $1 AND active = $2`.
'P' => {
self.buffer.put(&message[..]);
}
// Bind
// The placeholder's replacements are here, e.g. 'user@email.com' and 'true'
'B' => {
self.buffer.put(&message[..]);
}
// Describe
// Command a client can issue to describe a previously prepared named statement.
'D' => {
self.buffer.put(&message[..]);
}
// Execute
// Execute a prepared statement prepared in `P` and bound in `B`.
'E' => {
self.buffer.put(&message[..]);
}
// Sync
// Frontend (client) is asking for the query result now.
'S' => {
debug!("Sending query to server");
self.buffer.put(&message[..]);
let first_message_code = (*self.buffer.get(0).unwrap_or(&0)) as char;
// Almost certainly true
if first_message_code == 'P' {
// Message layout
// P followed by 32 int followed by null-terminated statement name
// So message code should be in offset 0 of the buffer, first character
// in prepared statement name would be index 5
let first_char_in_name = *self.buffer.get(5).unwrap_or(&0);
if first_char_in_name != 0 {
// This is a named prepared statement
// Server connection state will need to be cleared at checkin
server.mark_dirty();
}
}
match self.send_and_receive_loop(code, None, server, &address, &pool).await {
Ok(_) => self.retry_buffer = None,
Err(err) => {
if server.is_bad() && !server.in_transaction() && server.address().role == Role::Replica {
match self.retry_buffer {
Some(ref mut retry_buffer) => {
if retry_buffer.retry_count < 3 {
retry_buffer.retry_count += 1;
return Ok(ClientFlowControl::Retry);
}
self.retry_buffer = None;
return Err(err);
},
None => {
let buffer = self.buffer.clone();
self.buffer.clear();
self.retry_buffer = Some(RetryBuffer { buffer: message, retry_count: 0 });
return Ok(ClientFlowControl::Retry);
}
}
}
}
}
self.buffer.clear();
if !server.in_transaction() {
self.stats.transaction(self.process_id, server.server_id());
// Release server back to the pool if we are in transaction mode.
// If we are in session mode, we keep the server until the client disconnects.
if self.transaction_mode {
return Ok(ClientFlowControl::ReleaseConnection);
}
}
}
// CopyData
'd' => {
self.buffer.put(&message[..]);
// Want to limit buffer size
if self.buffer.len() > 8196 {
// Forward the data to the server,
self.send_server_message(server, &self.buffer, &address, &pool)
.await?;
self.buffer.clear();
}
}
// CopyDone or CopyFail
// Copy is done, successfully or not.
'c' | 'f' => {
// We may already have some copy data in the buffer, add this message to buffer
self.buffer.put(&message[..]);
self.send_server_message(server, &self.buffer, &address, &pool)
.await?;
// Clear the buffer
self.buffer.clear();
let response = self.receive_server_message(server, &address, &pool).await?;
match write_all_half(&mut self.write, &response).await {
Ok(_) => (),
Err(err) => {
server.mark_bad();
return Err(err);
}
};
if !server.in_transaction() {
self.stats.transaction(self.process_id, server.server_id());
// Release server back to the pool if we are in transaction mode.
// If we are in session mode, we keep the server until the client disconnects.
if self.transaction_mode {
return Ok(ClientFlowControl::ReleaseConnection);
}
}
}
// Some unexpected message. We either did not implement the protocol correctly
// or this is not a Postgres client we're talking to.
_ => {
return Err(Error::ProtocolSyncError("bad message code".to_string()));
}
}
return Ok(ClientFlowControl::PerformNextCommand);
}
/// Handle a connected and authenticated client.
pub async fn handle(&mut self) -> Result<(), Error> {
// The client wants to cancel a query it has issued previously.
if self.cancel_mode {
return self.cancel_query().await;
}
// The query router determines where the query is going to go,
// e.g. primary, replica, which shard.
let mut query_router = QueryRouter::new();
self.stats.client_register(
self.process_id,
self.pool_name.clone(),
self.username.clone(),
self.application_name.clone(),
);
// Our custom protocol loop.
// We expect the client to either start a transaction with regular queries
// or issue commands for our sharding and server selection protocol.
loop {
self.client_proc(&mut query_router).await?;
}
}
/// Retrieve connection pool, if it exists.