Compare commits

...

3 Commits

Author SHA1 Message Date
Lev Kokotov
fd3623ff13 mm 2023-04-22 08:02:20 -07:00
Lev Kokotov
088f1a7dae remove debug msg 2023-04-22 07:47:19 -07:00
Lev Kokotov
ab7ac16974 reqs 2023-04-22 07:40:21 -07:00
6 changed files with 143 additions and 14 deletions

View File

@@ -110,6 +110,10 @@ python3 tests/python/tests.py || exit 1
start_pgcat "info" start_pgcat "info"
python3 tests/python/async_test.py
start_pgcat "info"
# Admin tests # Admin tests
export PGPASSWORD=admin_pass export PGPASSWORD=admin_pass
psql -U admin_user -e -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW STATS' > /dev/null psql -U admin_user -e -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW STATS' > /dev/null

View File

@@ -932,7 +932,7 @@ where
} }
// Grab a server from the pool. // Grab a server from the pool.
let connection = match pool let mut connection = match pool
.get(query_router.shard(), query_router.role(), &self.stats) .get(query_router.shard(), query_router.role(), &self.stats)
.await .await
{ {
@@ -975,9 +975,8 @@ where
} }
}; };
let mut reference = connection.0; let server = &mut *connection.0;
let address = connection.1; let address = connection.1;
let server = &mut *reference;
// Server is assigned to the client in case the client wants to // Server is assigned to the client in case the client wants to
// cancel a query later. // cancel a query later.
@@ -1000,6 +999,7 @@ where
// Set application_name. // Set application_name.
server.set_name(&self.application_name).await?; server.set_name(&self.application_name).await?;
server.switch_async(false);
let mut initial_message = Some(message); let mut initial_message = Some(message);
@@ -1019,12 +1019,37 @@ where
None => { None => {
trace!("Waiting for message inside transaction or in session mode"); trace!("Waiting for message inside transaction or in session mode");
match tokio::time::timeout( let message = tokio::select! {
message = tokio::time::timeout(
idle_client_timeout_duration, idle_client_timeout_duration,
read_message(&mut self.read), read_message(&mut self.read),
) ) => message,
.await
{ server_message = server.recv() => {
debug!("Got async message");
let server_message = match server_message {
Ok(message) => message,
Err(err) => {
pool.ban(&address, BanReason::MessageReceiveFailed, Some(&self.stats));
server.mark_bad();
return Err(err);
}
};
match write_all_half(&mut self.write, &server_message).await {
Ok(_) => (),
Err(err) => {
server.mark_bad();
return Err(err);
}
};
continue;
}
};
match message {
Ok(Ok(message)) => message, Ok(Ok(message)) => message,
Ok(Err(err)) => { Ok(Err(err)) => {
// Client disconnected inside a transaction. // Client disconnected inside a transaction.
@@ -1141,9 +1166,14 @@ where
// Sync // Sync
// Frontend (client) is asking for the query result now. // Frontend (client) is asking for the query result now.
'S' => { 'S' | 'H' => {
debug!("Sending query to server"); debug!("Sending query to server");
if code == 'H' {
server.switch_async(true);
debug!("Client requested flush, going async");
}
self.buffer.put(&message[..]); self.buffer.put(&message[..]);
let first_message_code = (*self.buffer.get(0).unwrap_or(&0)) as char; let first_message_code = (*self.buffer.get(0).unwrap_or(&0)) as char;

View File

@@ -777,6 +777,7 @@ impl ConnectionPool {
self.databases.len() self.databases.len()
} }
/// Retrieve all bans for all servers.
pub fn get_bans(&self) -> Vec<(Address, (BanReason, NaiveDateTime))> { pub fn get_bans(&self) -> Vec<(Address, (BanReason, NaiveDateTime))> {
let mut bans: Vec<(Address, (BanReason, NaiveDateTime))> = Vec::new(); let mut bans: Vec<(Address, (BanReason, NaiveDateTime))> = Vec::new();
let guard = self.banlist.read(); let guard = self.banlist.read();
@@ -788,7 +789,7 @@ impl ConnectionPool {
return bans; return bans;
} }
/// Get the address from the host url /// Get the address from the host url.
pub fn get_addresses_from_host(&self, host: &str) -> Vec<Address> { pub fn get_addresses_from_host(&self, host: &str) -> Vec<Address> {
let mut addresses = Vec::new(); let mut addresses = Vec::new();
for shard in 0..self.shards() { for shard in 0..self.shards() {
@@ -827,10 +828,13 @@ impl ConnectionPool {
&self.addresses[shard][server] &self.addresses[shard][server]
} }
/// Get server settings retrieved at connection setup.
pub fn server_info(&self) -> BytesMut { pub fn server_info(&self) -> BytesMut {
self.server_info.read().clone() self.server_info.read().clone()
} }
/// Calculate how many used connections in the pool
/// for the given server.
fn busy_connection_count(&self, address: &Address) -> u32 { fn busy_connection_count(&self, address: &Address) -> u32 {
let state = self.pool_state(address.shard, address.address_index); let state = self.pool_state(address.shard, address.address_index);
let idle = state.idle_connections; let idle = state.idle_connections;

View File

@@ -38,6 +38,7 @@ pub struct Server {
/// Our server response buffer. We buffer data before we give it to the client. /// Our server response buffer. We buffer data before we give it to the client.
buffer: BytesMut, buffer: BytesMut,
is_async: bool,
/// Server information the server sent us over on startup. /// Server information the server sent us over on startup.
server_info: BytesMut, server_info: BytesMut,
@@ -450,6 +451,7 @@ impl Server {
read: BufReader::new(read), read: BufReader::new(read),
write, write,
buffer: BytesMut::with_capacity(8196), buffer: BytesMut::with_capacity(8196),
is_async: false,
server_info, server_info,
process_id, process_id,
secret_key, secret_key,
@@ -537,6 +539,16 @@ impl Server {
} }
} }
/// Switch to async mode, flushing messages as soon
/// as we receive them without buffering or waiting for "ReadyForQuery".
pub fn switch_async(&mut self, on: bool) {
if on {
self.is_async = true;
} else {
self.is_async = false;
}
}
/// Receive data from the server in response to a client request. /// Receive data from the server in response to a client request.
/// This method must be called multiple times while `self.is_data_available()` is true /// This method must be called multiple times while `self.is_data_available()` is true
/// in order to receive all data the server has to offer. /// in order to receive all data the server has to offer.
@@ -632,7 +644,10 @@ impl Server {
// DataRow // DataRow
'D' => { 'D' => {
// More data is available after this message, this is not the end of the reply. // More data is available after this message, this is not the end of the reply.
// If we're async, flush to client now.
if !self.is_async {
self.data_available = true; self.data_available = true;
}
// Don't flush yet, the more we buffer, the faster this goes...up to a limit. // Don't flush yet, the more we buffer, the faster this goes...up to a limit.
if self.buffer.len() >= 8196 { if self.buffer.len() >= 8196 {
@@ -645,7 +660,10 @@ impl Server {
// CopyOutResponse: copy is starting from the server to the client. // CopyOutResponse: copy is starting from the server to the client.
'H' => { 'H' => {
// If we're in async mode, flush now.
if !self.is_async {
self.data_available = true; self.data_available = true;
}
break; break;
} }
@@ -665,6 +683,10 @@ impl Server {
// Keep buffering until ReadyForQuery shows up. // Keep buffering until ReadyForQuery shows up.
_ => (), _ => (),
}; };
if self.is_async {
break;
}
} }
let bytes = self.buffer.clone(); let bytes = self.buffer.clone();

View File

@@ -0,0 +1,60 @@
import psycopg2
import asyncio
import asyncpg
PGCAT_HOST = "127.0.0.1"
PGCAT_PORT = "6432"
def regular_main():
# Connect to the PostgreSQL database
conn = psycopg2.connect(
host=PGCAT_HOST,
database="sharded_db",
user="sharding_user",
password="sharding_user",
port=PGCAT_PORT,
)
# Open a cursor to perform database operations
cur = conn.cursor()
# Execute a SQL query
cur.execute("SELECT 1")
# Fetch the results
rows = cur.fetchall()
# Print the results
for row in rows:
print(row[0])
# Close the cursor and the database connection
cur.close()
conn.close()
async def main():
# Connect to the PostgreSQL database
conn = await asyncpg.connect(
host=PGCAT_HOST,
database="sharded_db",
user="sharding_user",
password="sharding_user",
port=PGCAT_PORT,
)
# Execute a SQL query
for _ in range(25):
rows = await conn.fetch("SELECT 1")
# Print the results
for row in rows:
print(row[0])
# Close the database connection
await conn.close()
regular_main()
asyncio.run(main())

View File

@@ -1,2 +1,11 @@
psycopg2==2.9.3 asyncio==3.4.3
asyncpg==0.27.0
black==23.3.0
click==8.1.3
mypy-extensions==1.0.0
packaging==23.1
pathspec==0.11.1
platformdirs==3.2.0
psutil==5.9.1 psutil==5.9.1
psycopg2==2.9.3
tomli==2.0.1