diff --git a/src/admin.rs b/src/admin.rs index 4460f98..5879114 100644 --- a/src/admin.rs +++ b/src/admin.rs @@ -171,7 +171,7 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, res).await + write_all_half(stream, &res).await } /// Show PgCat version. @@ -189,7 +189,7 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, res).await + write_all_half(stream, &res).await } /// Show utilization of connection pools for each shard and replicas. @@ -250,7 +250,7 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, res).await + write_all_half(stream, &res).await } /// Show shards and replicas. @@ -317,7 +317,7 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, res).await + write_all_half(stream, &res).await } /// Ignore any SET commands the client sends. @@ -349,7 +349,7 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, res).await + write_all_half(stream, &res).await } /// Shows current configuration. @@ -395,7 +395,7 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, res).await + write_all_half(stream, &res).await } /// Show shard and replicas statistics. @@ -455,7 +455,7 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, res).await + write_all_half(stream, &res).await } /// Show currently connected clients @@ -505,7 +505,7 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, res).await + write_all_half(stream, &res).await } /// Show currently connected servers @@ -559,5 +559,5 @@ where res.put_i32(5); res.put_u8(b'I'); - write_all_half(stream, res).await + write_all_half(stream, &res).await } diff --git a/src/client.rs b/src/client.rs index b55906b..cfe12c0 100644 --- a/src/client.rs +++ b/src/client.rs @@ -861,7 +861,7 @@ where 'Q' => { debug!("Sending query to server"); - self.send_and_receive_loop(code, message, server, &address, &pool) + self.send_and_receive_loop(code, Some(&message), server, &address, &pool) .await?; if !server.in_transaction() { @@ -931,14 +931,8 @@ where } } - self.send_and_receive_loop( - code, - self.buffer.clone(), - server, - &address, - &pool, - ) - .await?; + self.send_and_receive_loop(code, None, server, &address, &pool) + .await?; self.buffer.clear(); @@ -955,21 +949,32 @@ where // CopyData 'd' => { - // Forward the data to the server, - // don't buffer it since it can be rather large. - self.send_server_message(server, message, &address, &pool) - .await?; + self.buffer.put(&message[..]); + + // Want to limit buffer size + if self.buffer.len() > 8196 { + // Forward the data to the server, + self.send_server_message(server, &self.buffer, &address, &pool) + .await?; + self.buffer.clear(); + } } // CopyDone or CopyFail // Copy is done, successfully or not. 'c' | 'f' => { - self.send_server_message(server, message, &address, &pool) + // We may already have some copy data in the buffer, add this message to buffer + self.buffer.put(&message[..]); + + self.send_server_message(server, &self.buffer, &address, &pool) .await?; + // Clear the buffer + self.buffer.clear(); + let response = self.receive_server_message(server, &address, &pool).await?; - match write_all_half(&mut self.write, response).await { + match write_all_half(&mut self.write, &response).await { Ok(_) => (), Err(err) => { server.mark_bad(); @@ -1016,13 +1021,18 @@ where async fn send_and_receive_loop( &mut self, code: char, - message: BytesMut, + message: Option<&BytesMut>, server: &mut Server, address: &Address, pool: &ConnectionPool, ) -> Result<(), Error> { debug!("Sending {} to server", code); + let message = match message { + Some(message) => message, + None => &self.buffer, + }; + self.send_server_message(server, message, address, pool) .await?; @@ -1032,7 +1042,7 @@ where loop { let response = self.receive_server_message(server, address, pool).await?; - match write_all_half(&mut self.write, response).await { + match write_all_half(&mut self.write, &response).await { Ok(_) => (), Err(err) => { server.mark_bad(); @@ -1058,7 +1068,7 @@ where async fn send_server_message( &self, server: &mut Server, - message: BytesMut, + message: &BytesMut, address: &Address, pool: &ConnectionPool, ) -> Result<(), Error> { diff --git a/src/messages.rs b/src/messages.rs index e831550..45a827c 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -258,7 +258,7 @@ where res.put_i32(len); res.put_slice(&set_complete[..]); - write_all_half(stream, res).await?; + write_all_half(stream, &res).await?; ready_for_query(stream).await } @@ -308,7 +308,7 @@ where res.put_i32(error.len() as i32 + 4); res.put(error); - write_all_half(stream, res).await + write_all_half(stream, &res).await } pub async fn wrong_password(stream: &mut S, user: &str) -> Result<(), Error> @@ -370,7 +370,7 @@ where // CommandComplete res.put(command_complete("SELECT 1")); - write_all_half(stream, res).await?; + write_all_half(stream, &res).await?; ready_for_query(stream).await } @@ -459,11 +459,11 @@ where } /// Write all the data in the buffer to the TcpStream, write owned half (see mpsc). -pub async fn write_all_half(stream: &mut S, buf: BytesMut) -> Result<(), Error> +pub async fn write_all_half(stream: &mut S, buf: &BytesMut) -> Result<(), Error> where S: tokio::io::AsyncWrite + std::marker::Unpin, { - match stream.write_all(&buf).await { + match stream.write_all(buf).await { Ok(_) => Ok(()), Err(_) => return Err(Error::SocketError(format!("Error writing to socket"))), } diff --git a/src/server.rs b/src/server.rs index 05a3b77..f2a6d38 100644 --- a/src/server.rs +++ b/src/server.rs @@ -381,7 +381,7 @@ impl Server { } /// Send messages to the server from the client. - pub async fn send(&mut self, messages: BytesMut) -> Result<(), Error> { + pub async fn send(&mut self, messages: &BytesMut) -> Result<(), Error> { self.stats.data_sent(messages.len(), self.server_id); match write_all_half(&mut self.write, messages).await { @@ -593,7 +593,7 @@ impl Server { pub async fn query(&mut self, query: &str) -> Result<(), Error> { let query = simple_query(query); - self.send(query).await?; + self.send(&query).await?; loop { let _ = self.recv().await?;