Compare commits

...

1 Commits

Author SHA1 Message Date
Mostafa Abdelraouf
52a980fa0a wip 2023-03-26 14:52:55 -05:00

View File

@@ -1,10 +1,15 @@
use crate::errors::Error; use crate::errors::Error;
use crate::pool::BanReason; use crate::pool::BanReason;
use bb8::PooledConnection;
/// Handle clients by pretending to be a PostgreSQL server. /// Handle clients by pretending to be a PostgreSQL server.
use bytes::{Buf, BufMut, BytesMut}; use bytes::{Buf, BufMut, BytesMut};
use hyper::server::conn;
use log::{debug, error, info, trace, warn}; use log::{debug, error, info, trace, warn};
use crate::pool::ServerPool;
use crate::config::Role;
use std::collections::HashMap; use std::collections::HashMap;
use std::ops::Add;
use std::time::Instant; use std::time::Instant;
use tokio::io::{split, AsyncReadExt, BufReader, ReadHalf, WriteHalf}; use tokio::io::{split, AsyncReadExt, BufReader, ReadHalf, WriteHalf};
use tokio::net::TcpStream; use tokio::net::TcpStream;
@@ -31,6 +36,19 @@ enum ClientConnectionType {
CancelQuery, CancelQuery,
} }
struct RetryBuffer {
buffer: BytesMut,
retry_count: u32,
}
pub enum ClientFlowControl {
Retry,
PerformNextCommand,
ReleaseConnection,
Disconnect
}
/// The client state. One of these is created per client. /// The client state. One of these is created per client.
pub struct Client<S, T> { pub struct Client<S, T> {
/// The reads are buffered (8K by default). /// The reads are buffered (8K by default).
@@ -92,6 +110,8 @@ pub struct Client<S, T> {
/// Used to notify clients about an impending shutdown /// Used to notify clients about an impending shutdown
shutdown: Receiver<()>, shutdown: Receiver<()>,
retry_buffer: Option<RetryBuffer>,
} }
/// Client entrypoint. /// Client entrypoint.
@@ -558,6 +578,7 @@ where
application_name: application_name.to_string(), application_name: application_name.to_string(),
shutdown, shutdown,
connected_to_server: false, connected_to_server: false,
retry_buffer: None,
}) })
} }
@@ -592,13 +613,12 @@ where
application_name: String::from("undefined"), application_name: String::from("undefined"),
shutdown, shutdown,
connected_to_server: false, connected_to_server: false,
retry_buffer: None,
}) })
} }
/// Handle a connected and authenticated client.
pub async fn handle(&mut self) -> Result<(), Error> { async fn cancel_query(&mut self) -> Result<(), Error> {
// The client wants to cancel a query it has issued previously.
if self.cancel_mode {
trace!("Sending CancelRequest"); trace!("Sending CancelRequest");
let (process_id, secret_key, address, port) = { let (process_id, secret_key, address, port) = {
@@ -624,20 +644,53 @@ where
return Server::cancel(&address, port, process_id, secret_key).await; return Server::cancel(&address, port, process_id, secret_key).await;
} }
// The query router determines where the query is going to go, async fn checkout_connection(&mut self, pool: &ConnectionPool, query_router: &mut QueryRouter) -> Result<(PooledConnection<ServerPool>, Address), Error> {
// e.g. primary, replica, which shard. // Grab a server from the pool.
let mut query_router = QueryRouter::new(); let mut connection = match pool
self.stats.client_register( .get(query_router.shard(), query_router.role(), self.process_id)
self.process_id, .await
self.pool_name.clone(), {
self.username.clone(), Ok(conn) => {
self.application_name.clone(), debug!("Got connection from pool");
conn
}
Err(err) => {
self.buffer.clear();
error_response(&mut self.write, "could not get connection from the pool")
.await?;
error!("Could not get connection from pool: {{ pool_name: {:?}, username: {:?}, shard: {:?}, role: \"{:?}\", error: \"{:?}\" }}",
self.pool_name.clone(), self.username.clone(), query_router.shard(), query_router.role(), err);
return Err(err);
}
};
let server = &mut connection.0;
let address = connection.1.clone();
// Server is assigned to the client in case the client wants to
// cancel a query later.
server.claim(self.process_id, self.secret_key);
self.connected_to_server = true;
// Update statistics
self.stats
.client_active(self.process_id, server.server_id());
self.last_address_id = Some(address.id);
self.last_server_id = Some(server.server_id());
debug!(
"Client {:?} talking to server {:?}",
self.addr,
server.address()
); );
// Our custom protocol loop. return Ok(connection);
// We expect the client to either start a transaction with regular queries }
// or issue commands for our sharding and server selection protocol.
loop { async fn client_proc(&mut self, query_router: &mut QueryRouter) -> Result<ClientFlowControl, Error> {
trace!( trace!(
"Client idle, waiting for message, transaction mode: {}", "Client idle, waiting for message, transaction mode: {}",
self.transaction_mode self.transaction_mode
@@ -648,7 +701,6 @@ where
// We can parse it here before grabbing a server from the pool, // We can parse it here before grabbing a server from the pool,
// in case the client is sending some custom protocol messages, e.g. // in case the client is sending some custom protocol messages, e.g.
// SET SHARDING KEY TO 'bigint'; // SET SHARDING KEY TO 'bigint';
let message = tokio::select! { let message = tokio::select! {
_ = self.shutdown.recv() => { _ = self.shutdown.recv() => {
if !self.admin { if !self.admin {
@@ -656,14 +708,13 @@ where
&mut self.write, &mut self.write,
"terminating connection due to administrator command" "terminating connection due to administrator command"
).await?; ).await?;
return Ok(()) return Ok(ClientFlowControl::Disconnect)
} } else {
// Admin clients ignore shutdown. // Admin clients ignore shutdown.
else {
read_message(&mut self.read).await? read_message(&mut self.read).await?
} }
}, },
message_result = read_message(&mut self.read) => message_result? message_result = read_message(&mut self.read) => message_result?
}; };
@@ -677,7 +728,7 @@ where
// to when we get the S message // to when we get the S message
'D' | 'E' => { 'D' | 'E' => {
self.buffer.put(&message[..]); self.buffer.put(&message[..]);
continue; return Ok(ClientFlowControl::PerformNextCommand);
} }
'Q' => { 'Q' => {
@@ -693,7 +744,7 @@ where
query_router.infer(&message); query_router.infer(&message);
} }
continue; return Ok(ClientFlowControl::PerformNextCommand);
} }
'B' => { 'B' => {
@@ -703,12 +754,12 @@ where
query_router.infer_shard_from_bind(&message); query_router.infer_shard_from_bind(&message);
} }
continue; return Ok(ClientFlowControl::PerformNextCommand);
} }
'X' => { 'X' => {
debug!("Client disconnecting"); debug!("Client disconnecting");
return Ok(()); return Ok(ClientFlowControl::Disconnect);
} }
_ => (), _ => (),
@@ -718,7 +769,7 @@ where
if self.admin { if self.admin {
debug!("Handling admin command"); debug!("Handling admin command");
handle_admin(&mut self.write, message, self.client_server_map.clone()).await?; handle_admin(&mut self.write, message, self.client_server_map.clone()).await?;
continue; return Ok(ClientFlowControl::PerformNextCommand);
} }
// Get a pool instance referenced by the most up-to-date // Get a pool instance referenced by the most up-to-date
@@ -740,7 +791,6 @@ where
match query_router.try_execute_command(&message) { match query_router.try_execute_command(&message) {
// Normal query, not a custom command. // Normal query, not a custom command.
None => (), None => (),
// SET SHARD TO // SET SHARD TO
Some((Command::SetShard, _)) => { Some((Command::SetShard, _)) => {
// Selected shard is not configured. // Selected shard is not configured.
@@ -761,97 +811,48 @@ where
} else { } else {
custom_protocol_response_ok(&mut self.write, "SET SHARD").await?; custom_protocol_response_ok(&mut self.write, "SET SHARD").await?;
} }
continue; return Ok(ClientFlowControl::PerformNextCommand);
} }
// SET PRIMARY READS TO // SET PRIMARY READS TO
Some((Command::SetPrimaryReads, _)) => { Some((Command::SetPrimaryReads, _)) => {
custom_protocol_response_ok(&mut self.write, "SET PRIMARY READS").await?; custom_protocol_response_ok(&mut self.write, "SET PRIMARY READS").await?;
continue; return Ok(ClientFlowControl::PerformNextCommand);
} }
// SET SHARDING KEY TO // SET SHARDING KEY TO
Some((Command::SetShardingKey, _)) => { Some((Command::SetShardingKey, _)) => {
custom_protocol_response_ok(&mut self.write, "SET SHARDING KEY").await?; custom_protocol_response_ok(&mut self.write, "SET SHARDING KEY").await?;
continue; return Ok(ClientFlowControl::PerformNextCommand);
} }
// SET SERVER ROLE TO // SET SERVER ROLE TO
Some((Command::SetServerRole, _)) => { Some((Command::SetServerRole, _)) => {
custom_protocol_response_ok(&mut self.write, "SET SERVER ROLE").await?; custom_protocol_response_ok(&mut self.write, "SET SERVER ROLE").await?;
continue; return Ok(ClientFlowControl::PerformNextCommand);
} }
// SHOW SERVER ROLE // SHOW SERVER ROLE
Some((Command::ShowServerRole, value)) => { Some((Command::ShowServerRole, value)) => {
show_response(&mut self.write, "server role", &value).await?; show_response(&mut self.write, "server role", &value).await?;
continue; return Ok(ClientFlowControl::PerformNextCommand);
} }
// SHOW SHARD // SHOW SHARD
Some((Command::ShowShard, value)) => { Some((Command::ShowShard, value)) => {
show_response(&mut self.write, "shard", &value).await?; show_response(&mut self.write, "shard", &value).await?;
continue; return Ok(ClientFlowControl::PerformNextCommand);
} }
// SHOW PRIMARY READS // SHOW PRIMARY READS
Some((Command::ShowPrimaryReads, value)) => { Some((Command::ShowPrimaryReads, value)) => {
show_response(&mut self.write, "primary reads", &value).await?; show_response(&mut self.write, "primary reads", &value).await?;
continue; return Ok(ClientFlowControl::PerformNextCommand);
} }
}; };
debug!("Waiting for connection from pool"); debug!("Waiting for connection from pool");
// Grab a server from the pool.
let connection = match pool
.get(query_router.shard(), query_router.role(), self.process_id)
.await
{
Ok(conn) => {
debug!("Got connection from pool");
conn
}
Err(err) => {
// Client is attempting to get results from the server,
// but we were unable to grab a connection from the pool
// We'll send back an error message and clean the extended
// protocol buffer
if message[0] as char == 'S' {
error!("Got Sync message but failed to get a connection from the pool");
self.buffer.clear();
}
error_response(&mut self.write, "could not get connection from the pool")
.await?;
error!("Could not get connection from pool: {{ pool_name: {:?}, username: {:?}, shard: {:?}, role: \"{:?}\", error: \"{:?}\" }}",
self.pool_name.clone(), self.username.clone(), query_router.shard(), query_router.role(), err);
continue;
}
};
let mut reference = connection.0;
let address = connection.1;
let server = &mut *reference;
// Server is assigned to the client in case the client wants to
// cancel a query later.
server.claim(self.process_id, self.secret_key);
self.connected_to_server = true;
// Update statistics
self.stats
.client_active(self.process_id, server.server_id());
self.last_address_id = Some(address.id);
self.last_server_id = Some(server.server_id());
debug!(
"Client {:?} talking to server {:?}",
self.addr,
server.address()
);
// TODO: investigate other parameters and set them too. // TODO: investigate other parameters and set them too.
// Set application_name. // Set application_name.
@@ -866,6 +867,35 @@ where
// If the client is in session mode, no more custom protocol // If the client is in session mode, no more custom protocol
// commands will be accepted. // commands will be accepted.
loop { loop {
match self.tx_proc(server, &pool, initial_message).await {
Ok(control_flow) => {
match control_flow {
ClientFlowControl::PerformNextCommand => {
initial_message = None;
continue;
}
control_flow_result => {
return Ok(control_flow_result);
}
}
},
Err(err) => return Err(err),
}
}
// The server is no longer bound to us, we can't cancel it's queries anymore.
debug!("Releasing server back into the pool");
server.checkin_cleanup().await?;
self.stats.server_idle(server.server_id());
self.connected_to_server = false;
self.release();
self.stats.client_idle(self.process_id);
}
#[inline(always)]
pub async fn tx_proc(&mut self, server: &mut Server, pool: &ConnectionPool, initial_message: Option<BytesMut> ) -> Result<ClientFlowControl, Error> {
let message = match initial_message { let message = match initial_message {
None => { None => {
trace!("Waiting for message inside transaction or in session mode"); trace!("Waiting for message inside transaction or in session mode");
@@ -881,10 +911,7 @@ where
} }
} }
} }
Some(message) => { Some(message) => message
initial_message = None;
message
}
}; };
// The message will be forwarded to the server intact. We still would like to // The message will be forwarded to the server intact. We still would like to
@@ -893,7 +920,7 @@ where
// Safe to unwrap because we know this message has a certain length and has the code // Safe to unwrap because we know this message has a certain length and has the code
// This reads the first byte without advancing the internal pointer and mutating the bytes // This reads the first byte without advancing the internal pointer and mutating the bytes
let code = *message.get(0).unwrap() as char; let code = *message.get(0).unwrap() as char;
let address = server.address();
trace!("Message: {}", code); trace!("Message: {}", code);
match code { match code {
@@ -901,8 +928,25 @@ where
'Q' => { 'Q' => {
debug!("Sending query to server"); debug!("Sending query to server");
self.send_and_receive_loop(code, Some(&message), server, &address, &pool) match self.send_and_receive_loop(code, Some(&message), server, &address, &pool).await {
.await?; Ok(_) => self.retry_buffer = None,
Err(_) => {
if server.is_bad() && !server.in_transaction() && server.address().role == Role::Replica {
match self.retry_buffer {
Some(ref mut retry_buffer) => {
if retry_buffer.retry_count < 3 {
retry_buffer.retry_count += 1;
return Ok(ClientFlowControl::Retry);
}
},
None => {
self.retry_buffer = Some(RetryBuffer { buffer: message, retry_count: 0 });
return Ok(ClientFlowControl::Retry);
}
}
}
},
}
if !server.in_transaction() { if !server.in_transaction() {
// Report transaction executed statistics. // Report transaction executed statistics.
@@ -911,7 +955,7 @@ where
// Release server back to the pool if we are in transaction mode. // Release server back to the pool if we are in transaction mode.
// If we are in session mode, we keep the server until the client disconnects. // If we are in session mode, we keep the server until the client disconnects.
if self.transaction_mode { if self.transaction_mode {
break; return Ok(ClientFlowControl::ReleaseConnection);
} }
} }
} }
@@ -921,7 +965,7 @@ where
server.checkin_cleanup().await?; server.checkin_cleanup().await?;
self.release(); self.release();
return Ok(()); return Ok(ClientFlowControl::Disconnect);
} }
// Parse // Parse
@@ -971,8 +1015,29 @@ where
} }
} }
self.send_and_receive_loop(code, None, server, &address, &pool) match self.send_and_receive_loop(code, None, server, &address, &pool).await {
.await?; Ok(_) => self.retry_buffer = None,
Err(err) => {
if server.is_bad() && !server.in_transaction() && server.address().role == Role::Replica {
match self.retry_buffer {
Some(ref mut retry_buffer) => {
if retry_buffer.retry_count < 3 {
retry_buffer.retry_count += 1;
return Ok(ClientFlowControl::Retry);
}
self.retry_buffer = None;
return Err(err);
},
None => {
let buffer = self.buffer.clone();
self.buffer.clear();
self.retry_buffer = Some(RetryBuffer { buffer: message, retry_count: 0 });
return Ok(ClientFlowControl::Retry);
}
}
}
}
}
self.buffer.clear(); self.buffer.clear();
@@ -982,7 +1047,7 @@ where
// Release server back to the pool if we are in transaction mode. // Release server back to the pool if we are in transaction mode.
// If we are in session mode, we keep the server until the client disconnects. // If we are in session mode, we keep the server until the client disconnects.
if self.transaction_mode { if self.transaction_mode {
break; return Ok(ClientFlowControl::ReleaseConnection);
} }
} }
} }
@@ -1028,7 +1093,7 @@ where
// Release server back to the pool if we are in transaction mode. // Release server back to the pool if we are in transaction mode.
// If we are in session mode, we keep the server until the client disconnects. // If we are in session mode, we keep the server until the client disconnects.
if self.transaction_mode { if self.transaction_mode {
break; return Ok(ClientFlowControl::ReleaseConnection);
} }
} }
} }
@@ -1036,19 +1101,39 @@ where
// Some unexpected message. We either did not implement the protocol correctly // Some unexpected message. We either did not implement the protocol correctly
// or this is not a Postgres client we're talking to. // or this is not a Postgres client we're talking to.
_ => { _ => {
error!("Unexpected code: {}", code); return Err(Error::ProtocolSyncError("bad message code".to_string()));
} }
} }
return Ok(ClientFlowControl::PerformNextCommand);
} }
// The server is no longer bound to us, we can't cancel it's queries anymore.
debug!("Releasing server back into the pool");
server.checkin_cleanup().await?;
self.stats.server_idle(server.server_id());
self.connected_to_server = false;
self.release();
self.stats.client_idle(self.process_id); /// Handle a connected and authenticated client.
pub async fn handle(&mut self) -> Result<(), Error> {
// The client wants to cancel a query it has issued previously.
if self.cancel_mode {
return self.cancel_query().await;
}
// The query router determines where the query is going to go,
// e.g. primary, replica, which shard.
let mut query_router = QueryRouter::new();
self.stats.client_register(
self.process_id,
self.pool_name.clone(),
self.username.clone(),
self.application_name.clone(),
);
// Our custom protocol loop.
// We expect the client to either start a transaction with regular queries
// or issue commands for our sharding and server selection protocol.
loop {
self.client_proc(&mut query_router).await?;
} }
} }