Buffer client CopyData messages (#284)

Buffers CopyData messages and removes buffer clone for the sync message
This commit is contained in:
zainkabani
2023-01-17 20:39:55 -05:00
committed by GitHub
parent 7894bba59b
commit 85ac3ef9a5
4 changed files with 44 additions and 34 deletions

View File

@@ -171,7 +171,7 @@ where
res.put_i32(5); res.put_i32(5);
res.put_u8(b'I'); res.put_u8(b'I');
write_all_half(stream, res).await write_all_half(stream, &res).await
} }
/// Show PgCat version. /// Show PgCat version.
@@ -189,7 +189,7 @@ where
res.put_i32(5); res.put_i32(5);
res.put_u8(b'I'); res.put_u8(b'I');
write_all_half(stream, res).await write_all_half(stream, &res).await
} }
/// Show utilization of connection pools for each shard and replicas. /// Show utilization of connection pools for each shard and replicas.
@@ -250,7 +250,7 @@ where
res.put_i32(5); res.put_i32(5);
res.put_u8(b'I'); res.put_u8(b'I');
write_all_half(stream, res).await write_all_half(stream, &res).await
} }
/// Show shards and replicas. /// Show shards and replicas.
@@ -317,7 +317,7 @@ where
res.put_i32(5); res.put_i32(5);
res.put_u8(b'I'); res.put_u8(b'I');
write_all_half(stream, res).await write_all_half(stream, &res).await
} }
/// Ignore any SET commands the client sends. /// Ignore any SET commands the client sends.
@@ -349,7 +349,7 @@ where
res.put_i32(5); res.put_i32(5);
res.put_u8(b'I'); res.put_u8(b'I');
write_all_half(stream, res).await write_all_half(stream, &res).await
} }
/// Shows current configuration. /// Shows current configuration.
@@ -395,7 +395,7 @@ where
res.put_i32(5); res.put_i32(5);
res.put_u8(b'I'); res.put_u8(b'I');
write_all_half(stream, res).await write_all_half(stream, &res).await
} }
/// Show shard and replicas statistics. /// Show shard and replicas statistics.
@@ -455,7 +455,7 @@ where
res.put_i32(5); res.put_i32(5);
res.put_u8(b'I'); res.put_u8(b'I');
write_all_half(stream, res).await write_all_half(stream, &res).await
} }
/// Show currently connected clients /// Show currently connected clients
@@ -505,7 +505,7 @@ where
res.put_i32(5); res.put_i32(5);
res.put_u8(b'I'); res.put_u8(b'I');
write_all_half(stream, res).await write_all_half(stream, &res).await
} }
/// Show currently connected servers /// Show currently connected servers
@@ -559,5 +559,5 @@ where
res.put_i32(5); res.put_i32(5);
res.put_u8(b'I'); res.put_u8(b'I');
write_all_half(stream, res).await write_all_half(stream, &res).await
} }

View File

@@ -861,7 +861,7 @@ where
'Q' => { 'Q' => {
debug!("Sending query to server"); debug!("Sending query to server");
self.send_and_receive_loop(code, message, server, &address, &pool) self.send_and_receive_loop(code, Some(&message), server, &address, &pool)
.await?; .await?;
if !server.in_transaction() { if !server.in_transaction() {
@@ -931,14 +931,8 @@ where
} }
} }
self.send_and_receive_loop( self.send_and_receive_loop(code, None, server, &address, &pool)
code, .await?;
self.buffer.clone(),
server,
&address,
&pool,
)
.await?;
self.buffer.clear(); self.buffer.clear();
@@ -955,21 +949,32 @@ where
// CopyData // CopyData
'd' => { 'd' => {
// Forward the data to the server, self.buffer.put(&message[..]);
// don't buffer it since it can be rather large.
self.send_server_message(server, message, &address, &pool) // Want to limit buffer size
.await?; if self.buffer.len() > 8196 {
// Forward the data to the server,
self.send_server_message(server, &self.buffer, &address, &pool)
.await?;
self.buffer.clear();
}
} }
// CopyDone or CopyFail // CopyDone or CopyFail
// Copy is done, successfully or not. // Copy is done, successfully or not.
'c' | 'f' => { 'c' | 'f' => {
self.send_server_message(server, message, &address, &pool) // We may already have some copy data in the buffer, add this message to buffer
self.buffer.put(&message[..]);
self.send_server_message(server, &self.buffer, &address, &pool)
.await?; .await?;
// Clear the buffer
self.buffer.clear();
let response = self.receive_server_message(server, &address, &pool).await?; let response = self.receive_server_message(server, &address, &pool).await?;
match write_all_half(&mut self.write, response).await { match write_all_half(&mut self.write, &response).await {
Ok(_) => (), Ok(_) => (),
Err(err) => { Err(err) => {
server.mark_bad(); server.mark_bad();
@@ -1016,13 +1021,18 @@ where
async fn send_and_receive_loop( async fn send_and_receive_loop(
&mut self, &mut self,
code: char, code: char,
message: BytesMut, message: Option<&BytesMut>,
server: &mut Server, server: &mut Server,
address: &Address, address: &Address,
pool: &ConnectionPool, pool: &ConnectionPool,
) -> Result<(), Error> { ) -> Result<(), Error> {
debug!("Sending {} to server", code); debug!("Sending {} to server", code);
let message = match message {
Some(message) => message,
None => &self.buffer,
};
self.send_server_message(server, message, address, pool) self.send_server_message(server, message, address, pool)
.await?; .await?;
@@ -1032,7 +1042,7 @@ where
loop { loop {
let response = self.receive_server_message(server, address, pool).await?; let response = self.receive_server_message(server, address, pool).await?;
match write_all_half(&mut self.write, response).await { match write_all_half(&mut self.write, &response).await {
Ok(_) => (), Ok(_) => (),
Err(err) => { Err(err) => {
server.mark_bad(); server.mark_bad();
@@ -1058,7 +1068,7 @@ where
async fn send_server_message( async fn send_server_message(
&self, &self,
server: &mut Server, server: &mut Server,
message: BytesMut, message: &BytesMut,
address: &Address, address: &Address,
pool: &ConnectionPool, pool: &ConnectionPool,
) -> Result<(), Error> { ) -> Result<(), Error> {

View File

@@ -258,7 +258,7 @@ where
res.put_i32(len); res.put_i32(len);
res.put_slice(&set_complete[..]); res.put_slice(&set_complete[..]);
write_all_half(stream, res).await?; write_all_half(stream, &res).await?;
ready_for_query(stream).await ready_for_query(stream).await
} }
@@ -308,7 +308,7 @@ where
res.put_i32(error.len() as i32 + 4); res.put_i32(error.len() as i32 + 4);
res.put(error); res.put(error);
write_all_half(stream, res).await write_all_half(stream, &res).await
} }
pub async fn wrong_password<S>(stream: &mut S, user: &str) -> Result<(), Error> pub async fn wrong_password<S>(stream: &mut S, user: &str) -> Result<(), Error>
@@ -370,7 +370,7 @@ where
// CommandComplete // CommandComplete
res.put(command_complete("SELECT 1")); res.put(command_complete("SELECT 1"));
write_all_half(stream, res).await?; write_all_half(stream, &res).await?;
ready_for_query(stream).await ready_for_query(stream).await
} }
@@ -459,11 +459,11 @@ where
} }
/// Write all the data in the buffer to the TcpStream, write owned half (see mpsc). /// Write all the data in the buffer to the TcpStream, write owned half (see mpsc).
pub async fn write_all_half<S>(stream: &mut S, buf: BytesMut) -> Result<(), Error> pub async fn write_all_half<S>(stream: &mut S, buf: &BytesMut) -> Result<(), Error>
where where
S: tokio::io::AsyncWrite + std::marker::Unpin, S: tokio::io::AsyncWrite + std::marker::Unpin,
{ {
match stream.write_all(&buf).await { match stream.write_all(buf).await {
Ok(_) => Ok(()), Ok(_) => Ok(()),
Err(_) => return Err(Error::SocketError(format!("Error writing to socket"))), Err(_) => return Err(Error::SocketError(format!("Error writing to socket"))),
} }

View File

@@ -381,7 +381,7 @@ impl Server {
} }
/// Send messages to the server from the client. /// Send messages to the server from the client.
pub async fn send(&mut self, messages: BytesMut) -> Result<(), Error> { pub async fn send(&mut self, messages: &BytesMut) -> Result<(), Error> {
self.stats.data_sent(messages.len(), self.server_id); self.stats.data_sent(messages.len(), self.server_id);
match write_all_half(&mut self.write, messages).await { match write_all_half(&mut self.write, messages).await {
@@ -593,7 +593,7 @@ impl Server {
pub async fn query(&mut self, query: &str) -> Result<(), Error> { pub async fn query(&mut self, query: &str) -> Result<(), Error> {
let query = simple_query(query); let query = simple_query(query);
self.send(query).await?; self.send(&query).await?;
loop { loop {
let _ = self.recv().await?; let _ = self.recv().await?;