Buffer client CopyData messages (#284)

Buffers CopyData messages and removes buffer clone for the sync message
This commit is contained in:
zainkabani
2023-01-17 20:39:55 -05:00
committed by GitHub
parent 7894bba59b
commit 85ac3ef9a5
4 changed files with 44 additions and 34 deletions

View File

@@ -171,7 +171,7 @@ where
res.put_i32(5);
res.put_u8(b'I');
write_all_half(stream, res).await
write_all_half(stream, &res).await
}
/// Show PgCat version.
@@ -189,7 +189,7 @@ where
res.put_i32(5);
res.put_u8(b'I');
write_all_half(stream, res).await
write_all_half(stream, &res).await
}
/// Show utilization of connection pools for each shard and replicas.
@@ -250,7 +250,7 @@ where
res.put_i32(5);
res.put_u8(b'I');
write_all_half(stream, res).await
write_all_half(stream, &res).await
}
/// Show shards and replicas.
@@ -317,7 +317,7 @@ where
res.put_i32(5);
res.put_u8(b'I');
write_all_half(stream, res).await
write_all_half(stream, &res).await
}
/// Ignore any SET commands the client sends.
@@ -349,7 +349,7 @@ where
res.put_i32(5);
res.put_u8(b'I');
write_all_half(stream, res).await
write_all_half(stream, &res).await
}
/// Shows current configuration.
@@ -395,7 +395,7 @@ where
res.put_i32(5);
res.put_u8(b'I');
write_all_half(stream, res).await
write_all_half(stream, &res).await
}
/// Show shard and replicas statistics.
@@ -455,7 +455,7 @@ where
res.put_i32(5);
res.put_u8(b'I');
write_all_half(stream, res).await
write_all_half(stream, &res).await
}
/// Show currently connected clients
@@ -505,7 +505,7 @@ where
res.put_i32(5);
res.put_u8(b'I');
write_all_half(stream, res).await
write_all_half(stream, &res).await
}
/// Show currently connected servers
@@ -559,5 +559,5 @@ where
res.put_i32(5);
res.put_u8(b'I');
write_all_half(stream, res).await
write_all_half(stream, &res).await
}

View File

@@ -861,7 +861,7 @@ where
'Q' => {
debug!("Sending query to server");
self.send_and_receive_loop(code, message, server, &address, &pool)
self.send_and_receive_loop(code, Some(&message), server, &address, &pool)
.await?;
if !server.in_transaction() {
@@ -931,14 +931,8 @@ where
}
}
self.send_and_receive_loop(
code,
self.buffer.clone(),
server,
&address,
&pool,
)
.await?;
self.send_and_receive_loop(code, None, server, &address, &pool)
.await?;
self.buffer.clear();
@@ -955,21 +949,32 @@ where
// CopyData
'd' => {
// Forward the data to the server,
// don't buffer it since it can be rather large.
self.send_server_message(server, message, &address, &pool)
.await?;
self.buffer.put(&message[..]);
// Want to limit buffer size
if self.buffer.len() > 8196 {
// Forward the data to the server,
self.send_server_message(server, &self.buffer, &address, &pool)
.await?;
self.buffer.clear();
}
}
// CopyDone or CopyFail
// Copy is done, successfully or not.
'c' | 'f' => {
self.send_server_message(server, message, &address, &pool)
// We may already have some copy data in the buffer, add this message to buffer
self.buffer.put(&message[..]);
self.send_server_message(server, &self.buffer, &address, &pool)
.await?;
// Clear the buffer
self.buffer.clear();
let response = self.receive_server_message(server, &address, &pool).await?;
match write_all_half(&mut self.write, response).await {
match write_all_half(&mut self.write, &response).await {
Ok(_) => (),
Err(err) => {
server.mark_bad();
@@ -1016,13 +1021,18 @@ where
async fn send_and_receive_loop(
&mut self,
code: char,
message: BytesMut,
message: Option<&BytesMut>,
server: &mut Server,
address: &Address,
pool: &ConnectionPool,
) -> Result<(), Error> {
debug!("Sending {} to server", code);
let message = match message {
Some(message) => message,
None => &self.buffer,
};
self.send_server_message(server, message, address, pool)
.await?;
@@ -1032,7 +1042,7 @@ where
loop {
let response = self.receive_server_message(server, address, pool).await?;
match write_all_half(&mut self.write, response).await {
match write_all_half(&mut self.write, &response).await {
Ok(_) => (),
Err(err) => {
server.mark_bad();
@@ -1058,7 +1068,7 @@ where
async fn send_server_message(
&self,
server: &mut Server,
message: BytesMut,
message: &BytesMut,
address: &Address,
pool: &ConnectionPool,
) -> Result<(), Error> {

View File

@@ -258,7 +258,7 @@ where
res.put_i32(len);
res.put_slice(&set_complete[..]);
write_all_half(stream, res).await?;
write_all_half(stream, &res).await?;
ready_for_query(stream).await
}
@@ -308,7 +308,7 @@ where
res.put_i32(error.len() as i32 + 4);
res.put(error);
write_all_half(stream, res).await
write_all_half(stream, &res).await
}
pub async fn wrong_password<S>(stream: &mut S, user: &str) -> Result<(), Error>
@@ -370,7 +370,7 @@ where
// CommandComplete
res.put(command_complete("SELECT 1"));
write_all_half(stream, res).await?;
write_all_half(stream, &res).await?;
ready_for_query(stream).await
}
@@ -459,11 +459,11 @@ where
}
/// Write all the data in the buffer to the TcpStream, write owned half (see mpsc).
pub async fn write_all_half<S>(stream: &mut S, buf: BytesMut) -> Result<(), Error>
pub async fn write_all_half<S>(stream: &mut S, buf: &BytesMut) -> Result<(), Error>
where
S: tokio::io::AsyncWrite + std::marker::Unpin,
{
match stream.write_all(&buf).await {
match stream.write_all(buf).await {
Ok(_) => Ok(()),
Err(_) => return Err(Error::SocketError(format!("Error writing to socket"))),
}

View File

@@ -381,7 +381,7 @@ impl Server {
}
/// Send messages to the server from the client.
pub async fn send(&mut self, messages: BytesMut) -> Result<(), Error> {
pub async fn send(&mut self, messages: &BytesMut) -> Result<(), Error> {
self.stats.data_sent(messages.len(), self.server_id);
match write_all_half(&mut self.write, messages).await {
@@ -593,7 +593,7 @@ impl Server {
pub async fn query(&mut self, query: &str) -> Result<(), Error> {
let query = simple_query(query);
self.send(query).await?;
self.send(&query).await?;
loop {
let _ = self.recv().await?;