Don't send discard all when state is changed in transaction (#186)

* Don't send discard all when state is changed in transaction

* Remove unnecessary clone

* spelling

* Move transaction check to SET command

* Add test for set command in transaction

* type

* Update comments

* Update comments

* use moves instead of clones for initial message

* don't make message mutable

* Update unwrap

* but i'm not a wrapper

* Add set local test

* change continue
This commit is contained in:
zainkabani
2022-10-13 22:33:12 -04:00
committed by GitHub
parent eceb7f092e
commit 19f635881a
3 changed files with 67 additions and 29 deletions

View File

@@ -601,7 +601,7 @@ where
// in case the client is sending some custom protocol messages, e.g.
// SET SHARDING KEY TO 'bigint';
let mut message = tokio::select! {
let message = tokio::select! {
_ = self.shutdown.recv() => {
if !self.admin {
error_response_terminal(
@@ -792,6 +792,8 @@ where
// Set application_name.
server.set_name(&self.application_name).await?;
let mut initial_message = Some(message);
// Transaction loop. Multiple queries can be issued by the client here.
// The connection belongs to the client until the transaction is over,
// or until the client disconnects if we are in session mode.
@@ -799,40 +801,42 @@ where
// If the client is in session mode, no more custom protocol
// commands will be accepted.
loop {
let mut message = if message.len() == 0 {
trace!("Waiting for message inside transaction or in session mode");
let message = match initial_message {
None => {
trace!("Waiting for message inside transaction or in session mode");
match read_message(&mut self.read).await {
Ok(message) => message,
Err(err) => {
// Client disconnected inside a transaction.
// Clean up the server and re-use it.
server.checkin_cleanup().await?;
match read_message(&mut self.read).await {
Ok(message) => message,
Err(err) => {
// Client disconnected inside a transaction.
// Clean up the server and re-use it.
server.checkin_cleanup().await?;
return Err(err);
return Err(err);
}
}
}
} else {
let msg = message.clone();
message.clear();
msg
Some(message) => {
initial_message = None;
message
}
};
// The message will be forwarded to the server intact. We still would like to
// parse it below to figure out what to do with it.
let original = message.clone();
let code = message.get_u8() as char;
let _len = message.get_i32() as usize;
// Safe to unwrap because we know this message has a certain length and has the code
// This reads the first byte without advancing the internal pointer and mutating the bytes
let code = *message.get(0).unwrap() as char;
trace!("Message: {}", code);
match code {
// ReadyForQuery
// Query
'Q' => {
debug!("Sending query to server");
self.send_and_receive_loop(code, original, server, &address, &pool)
self.send_and_receive_loop(code, message, server, &address, &pool)
.await?;
if !server.in_transaction() {
@@ -858,25 +862,25 @@ where
// Parse
// The query with placeholders is here, e.g. `SELECT * FROM users WHERE email = $1 AND active = $2`.
'P' => {
self.buffer.put(&original[..]);
self.buffer.put(&message[..]);
}
// Bind
// The placeholder's replacements are here, e.g. 'user@email.com' and 'true'
'B' => {
self.buffer.put(&original[..]);
self.buffer.put(&message[..]);
}
// Describe
// Command a client can issue to describe a previously prepared named statement.
'D' => {
self.buffer.put(&original[..]);
self.buffer.put(&message[..]);
}
// Execute
// Execute a prepared statement prepared in `P` and bound in `B`.
'E' => {
self.buffer.put(&original[..]);
self.buffer.put(&message[..]);
}
// Sync
@@ -884,9 +888,8 @@ where
'S' => {
debug!("Sending query to server");
self.buffer.put(&original[..]);
self.buffer.put(&message[..]);
// Clone after freeze does not allocate
let first_message_code = (*self.buffer.get(0).unwrap_or(&0)) as char;
// Almost certainly true
@@ -929,14 +932,14 @@ where
'd' => {
// Forward the data to the server,
// don't buffer it since it can be rather large.
self.send_server_message(server, original, &address, &pool)
self.send_server_message(server, message, &address, &pool)
.await?;
}
// CopyDone or CopyFail
// Copy is done, successfully or not.
'c' | 'f' => {
self.send_server_message(server, original, &address, &pool)
self.send_server_message(server, message, &address, &pool)
.await?;
let response = self.receive_server_message(server, &address, &pool).await?;

View File

@@ -457,7 +457,17 @@ impl Server {
// which can leak between clients. This is a best effort to block bad clients
// from poisoning a transaction-mode pool by setting inappropriate session variables
match command_tag.as_str() {
"SET\0" | "PREPARE\0" => {
"SET\0" => {
// We don't detect set statements in transactions
// No great way to differentiate between set and set local
// As a result, we will miss cases when set statements are used in transactions
// This will reduce amount of discard statements sent
if !self.in_transaction {
debug!("Server connection marked for clean up");
self.needs_cleanup = true;
}
}
"PREPARE\0" => {
debug!("Server connection marked for clean up");
self.needs_cleanup = true;
}
@@ -595,7 +605,7 @@ impl Server {
self.query("ROLLBACK").await?;
}
// Client disconnected but it perfromed session-altering operations such as
// Client disconnected but it performed session-altering operations such as
// SET statement_timeout to 1 or create a prepared statement. We clear that
// to avoid leaking state between clients. For performance reasons we only
// send `DISCARD ALL` if we think the session is altered instead of just sending

View File

@@ -189,5 +189,30 @@ describe "Miscellaneous" do
expect(processes.primary.count_query("DISCARD ALL")).to eq(10)
end
end
context "transaction mode with transactions" do
let(:processes) { Helpers::Pgcat.single_shard_setup("sharded_db", 5, "transaction") }
it "Does not clear set statement state when declared in a transaction" do
10.times do
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
conn.async_exec("SET SERVER ROLE to 'primary'")
conn.async_exec("BEGIN")
conn.async_exec("SET statement_timeout to 1000")
conn.async_exec("COMMIT")
conn.close
end
expect(processes.primary.count_query("DISCARD ALL")).to eq(0)
10.times do
conn = PG::connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
conn.async_exec("SET SERVER ROLE to 'primary'")
conn.async_exec("BEGIN")
conn.async_exec("SET LOCAL statement_timeout to 1000")
conn.async_exec("COMMIT")
conn.close
end
expect(processes.primary.count_query("DISCARD ALL")).to eq(0)
end
end
end
end