Support for prepared statements (#474)

* Start prepared statements

* parse

* Ok

* optional

* dont rewrite anonymous prepared stmts

* Dont rewrite anonymous prep statements

* hm?

* prep statements

* I see!

* comment

* Print config value

* Rewrite bind and add sqlx test

* fmt

* ok

* Fix

* Fix stats

* its late

* clean up PREPARE
This commit is contained in:
Lev Kokotov
2023-06-16 12:57:44 -07:00
committed by GitHub
parent 94c781881f
commit c7d6273037
14 changed files with 1954 additions and 10 deletions

View File

@@ -3,8 +3,9 @@ use crate::pool::BanReason;
/// Handle clients by pretending to be a PostgreSQL server.
use bytes::{Buf, BufMut, BytesMut};
use log::{debug, error, info, trace, warn};
use once_cell::sync::Lazy;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::{atomic::AtomicUsize, Arc};
use std::time::Instant;
use tokio::io::{split, AsyncReadExt, BufReader, ReadHalf, WriteHalf};
use tokio::net::TcpStream;
@@ -13,7 +14,9 @@ use tokio::sync::mpsc::Sender;
use crate::admin::{generate_server_info_for_admin, handle_admin};
use crate::auth_passthrough::refetch_auth_hash;
use crate::config::{get_config, get_idle_client_in_transaction_timeout, Address, PoolMode};
use crate::config::{
get_config, get_idle_client_in_transaction_timeout, get_prepared_statements, Address, PoolMode,
};
use crate::constants::*;
use crate::messages::*;
use crate::plugins::PluginOutput;
@@ -25,6 +28,11 @@ use crate::tls::Tls;
use tokio_rustls::server::TlsStream;
/// Incrementally count prepared statements
/// to avoid random conflicts in places where the random number generator is weak.
pub static PREPARED_STATEMENT_COUNTER: Lazy<Arc<AtomicUsize>> =
Lazy::new(|| Arc::new(AtomicUsize::new(0)));
/// Type of connection received from client.
enum ClientConnectionType {
Startup,
@@ -93,6 +101,9 @@ pub struct Client<S, T> {
/// Used to notify clients about an impending shutdown
shutdown: Receiver<()>,
/// Prepared statements
prepared_statements: HashMap<String, Parse>,
}
/// Client entrypoint.
@@ -682,6 +693,7 @@ where
application_name: application_name.to_string(),
shutdown,
connected_to_server: false,
prepared_statements: HashMap::new(),
})
}
@@ -716,6 +728,7 @@ where
application_name: String::from("undefined"),
shutdown,
connected_to_server: false,
prepared_statements: HashMap::new(),
})
}
@@ -757,6 +770,10 @@ where
// Result returned by one of the plugins.
let mut plugin_output = None;
// Prepared statement being executed
let mut prepared_statement = None;
let mut will_prepare = false;
// Our custom protocol loop.
// We expect the client to either start a transaction with regular queries
// or issue commands for our sharding and server selection protocol.
@@ -766,13 +783,16 @@ where
self.transaction_mode
);
// Should we rewrite prepared statements and bind messages?
let mut prepared_statements_enabled = get_prepared_statements();
// Read a complete message from the client, which normally would be
// either a `Q` (query) or `P` (prepare, extended protocol).
// We can parse it here before grabbing a server from the pool,
// in case the client is sending some custom protocol messages, e.g.
// SET SHARDING KEY TO 'bigint';
let message = tokio::select! {
let mut message = tokio::select! {
_ = self.shutdown.recv() => {
if !self.admin {
error_response_terminal(
@@ -800,7 +820,21 @@ where
// allocate a connection, we wouldn't be able to send back an error message
// to the client so we buffer them and defer the decision to error out or not
// to when we get the S message
'D' | 'E' => {
'D' => {
if prepared_statements_enabled {
let name;
(name, message) = self.rewrite_describe(message).await?;
if let Some(name) = name {
prepared_statement = Some(name);
}
}
self.buffer.put(&message[..]);
continue;
}
'E' => {
self.buffer.put(&message[..]);
continue;
}
@@ -830,6 +864,11 @@ where
}
'P' => {
if prepared_statements_enabled {
(prepared_statement, message) = self.rewrite_parse(message)?;
will_prepare = true;
}
self.buffer.put(&message[..]);
if query_router.query_parser_enabled() {
@@ -846,6 +885,10 @@ where
}
'B' => {
if prepared_statements_enabled {
(prepared_statement, message) = self.rewrite_bind(message).await?;
}
self.buffer.put(&message[..]);
if query_router.query_parser_enabled() {
@@ -1054,7 +1097,48 @@ where
// If the client is in session mode, no more custom protocol
// commands will be accepted.
loop {
let message = match initial_message {
// Only check if we should rewrite prepared statements
// in session mode. In transaction mode, we check at the beginning of
// each transaction.
if !self.transaction_mode {
prepared_statements_enabled = get_prepared_statements();
}
debug!("Prepared statement active: {:?}", prepared_statement);
// We are processing a prepared statement.
if let Some(ref name) = prepared_statement {
debug!("Checking prepared statement is on server");
// Get the prepared statement the server expects to see.
let statement = match self.prepared_statements.get(name) {
Some(statement) => {
debug!("Prepared statement `{}` found in cache", name);
statement
}
None => {
return Err(Error::ClientError(format!(
"prepared statement `{}` not found",
name
)))
}
};
// Since it's already in the buffer, we don't need to prepare it on this server.
if will_prepare {
server.will_prepare(&statement.name);
will_prepare = false;
} else {
// The statement is not prepared on the server, so we need to prepare it.
if server.should_prepare(&statement.name) {
server.prepare(statement).await?;
}
}
// Done processing the prepared statement.
prepared_statement = None;
}
let mut message = match initial_message {
None => {
trace!("Waiting for message inside transaction or in session mode");
@@ -1173,6 +1257,11 @@ where
// Parse
// The query with placeholders is here, e.g. `SELECT * FROM users WHERE email = $1 AND active = $2`.
'P' => {
if prepared_statements_enabled {
(prepared_statement, message) = self.rewrite_parse(message)?;
will_prepare = true;
}
if query_router.query_parser_enabled() {
if let Ok(ast) = QueryRouter::parse(&message) {
if let Ok(output) = query_router.execute_plugins(&ast).await {
@@ -1187,12 +1276,25 @@ where
// Bind
// The placeholder's replacements are here, e.g. 'user@email.com' and 'true'
'B' => {
if prepared_statements_enabled {
(prepared_statement, message) = self.rewrite_bind(message).await?;
}
self.buffer.put(&message[..]);
}
// Describe
// Command a client can issue to describe a previously prepared named statement.
'D' => {
if prepared_statements_enabled {
let name;
(name, message) = self.rewrite_describe(message).await?;
if let Some(name) = name {
prepared_statement = Some(name);
}
}
self.buffer.put(&message[..]);
}
@@ -1235,7 +1337,7 @@ where
let first_message_code = (*self.buffer.get(0).unwrap_or(&0)) as char;
// Almost certainly true
if first_message_code == 'P' {
if first_message_code == 'P' && !prepared_statements_enabled {
// Message layout
// P followed by 32 int followed by null-terminated statement name
// So message code should be in offset 0 of the buffer, first character
@@ -1363,6 +1465,107 @@ where
}
}
/// Rewrite Parse (F) message to set the prepared statement name to one we control.
/// Save it into the client cache.
fn rewrite_parse(&mut self, message: BytesMut) -> Result<(Option<String>, BytesMut), Error> {
let parse: Parse = (&message).try_into()?;
let name = parse.name.clone();
// Don't rewrite anonymous prepared statements
if parse.anonymous() {
debug!("Anonymous prepared statement");
return Ok((None, message));
}
let parse = parse.rename();
debug!(
"Renamed prepared statement `{}` to `{}` and saved to cache",
name, parse.name
);
self.prepared_statements.insert(name.clone(), parse.clone());
Ok((Some(name), parse.try_into()?))
}
/// Rewrite the Bind (F) message to use the prepared statement name
/// saved in the client cache.
async fn rewrite_bind(
&mut self,
message: BytesMut,
) -> Result<(Option<String>, BytesMut), Error> {
let bind: Bind = (&message).try_into()?;
let name = bind.prepared_statement.clone();
if bind.anonymous() {
debug!("Anonymous bind message");
return Ok((None, message));
}
match self.prepared_statements.get(&name) {
Some(prepared_stmt) => {
let bind = bind.reassign(prepared_stmt);
debug!("Rewrote bind `{}` to `{}`", name, bind.prepared_statement);
Ok((Some(name), bind.try_into()?))
}
None => {
debug!("Got bind for unknown prepared statement {:?}", bind);
error_response(
&mut self.write,
&format!(
"prepared statement \"{}\" does not exist",
bind.prepared_statement
),
)
.await?;
Err(Error::ClientError(format!(
"Prepared statement `{}` doesn't exist",
name
)))
}
}
}
/// Rewrite the Describe (F) message to use the prepared statement name
/// saved in the client cache.
async fn rewrite_describe(
&mut self,
message: BytesMut,
) -> Result<(Option<String>, BytesMut), Error> {
let describe: Describe = (&message).try_into()?;
let name = describe.statement_name.clone();
if describe.anonymous() {
debug!("Anonymous describe");
return Ok((None, message));
}
match self.prepared_statements.get(&name) {
Some(prepared_stmt) => {
let describe = describe.rename(&prepared_stmt.name);
debug!(
"Rewrote describe `{}` to `{}`",
name, describe.statement_name
);
Ok((Some(name), describe.try_into()?))
}
None => {
debug!("Got describe for unknown prepared statement {:?}", describe);
Ok((None, message))
}
}
}
/// Release the server from the client: it can't cancel its queries anymore.
pub fn release(&self) {
let mut guard = self.client_server_map.lock();