mirror of
https://github.com/postgresml/pgcat.git
synced 2026-03-23 01:16:30 +00:00
Compare commits
3 Commits
mostafa_de
...
levkk-asyn
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fd3623ff13 | ||
|
|
088f1a7dae | ||
|
|
ab7ac16974 |
@@ -110,6 +110,10 @@ python3 tests/python/tests.py || exit 1
|
||||
|
||||
start_pgcat "info"
|
||||
|
||||
python3 tests/python/async_test.py
|
||||
|
||||
start_pgcat "info"
|
||||
|
||||
# Admin tests
|
||||
export PGPASSWORD=admin_pass
|
||||
psql -U admin_user -e -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW STATS' > /dev/null
|
||||
|
||||
@@ -932,7 +932,7 @@ where
|
||||
}
|
||||
|
||||
// Grab a server from the pool.
|
||||
let connection = match pool
|
||||
let mut connection = match pool
|
||||
.get(query_router.shard(), query_router.role(), &self.stats)
|
||||
.await
|
||||
{
|
||||
@@ -975,9 +975,8 @@ where
|
||||
}
|
||||
};
|
||||
|
||||
let mut reference = connection.0;
|
||||
let server = &mut *connection.0;
|
||||
let address = connection.1;
|
||||
let server = &mut *reference;
|
||||
|
||||
// Server is assigned to the client in case the client wants to
|
||||
// cancel a query later.
|
||||
@@ -1000,6 +999,7 @@ where
|
||||
|
||||
// Set application_name.
|
||||
server.set_name(&self.application_name).await?;
|
||||
server.switch_async(false);
|
||||
|
||||
let mut initial_message = Some(message);
|
||||
|
||||
@@ -1019,12 +1019,37 @@ where
|
||||
None => {
|
||||
trace!("Waiting for message inside transaction or in session mode");
|
||||
|
||||
match tokio::time::timeout(
|
||||
idle_client_timeout_duration,
|
||||
read_message(&mut self.read),
|
||||
)
|
||||
.await
|
||||
{
|
||||
let message = tokio::select! {
|
||||
message = tokio::time::timeout(
|
||||
idle_client_timeout_duration,
|
||||
read_message(&mut self.read),
|
||||
) => message,
|
||||
|
||||
server_message = server.recv() => {
|
||||
debug!("Got async message");
|
||||
|
||||
let server_message = match server_message {
|
||||
Ok(message) => message,
|
||||
Err(err) => {
|
||||
pool.ban(&address, BanReason::MessageReceiveFailed, Some(&self.stats));
|
||||
server.mark_bad();
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
|
||||
match write_all_half(&mut self.write, &server_message).await {
|
||||
Ok(_) => (),
|
||||
Err(err) => {
|
||||
server.mark_bad();
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
match message {
|
||||
Ok(Ok(message)) => message,
|
||||
Ok(Err(err)) => {
|
||||
// Client disconnected inside a transaction.
|
||||
@@ -1141,9 +1166,14 @@ where
|
||||
|
||||
// Sync
|
||||
// Frontend (client) is asking for the query result now.
|
||||
'S' => {
|
||||
'S' | 'H' => {
|
||||
debug!("Sending query to server");
|
||||
|
||||
if code == 'H' {
|
||||
server.switch_async(true);
|
||||
debug!("Client requested flush, going async");
|
||||
}
|
||||
|
||||
self.buffer.put(&message[..]);
|
||||
|
||||
let first_message_code = (*self.buffer.get(0).unwrap_or(&0)) as char;
|
||||
|
||||
@@ -777,6 +777,7 @@ impl ConnectionPool {
|
||||
self.databases.len()
|
||||
}
|
||||
|
||||
/// Retrieve all bans for all servers.
|
||||
pub fn get_bans(&self) -> Vec<(Address, (BanReason, NaiveDateTime))> {
|
||||
let mut bans: Vec<(Address, (BanReason, NaiveDateTime))> = Vec::new();
|
||||
let guard = self.banlist.read();
|
||||
@@ -788,7 +789,7 @@ impl ConnectionPool {
|
||||
return bans;
|
||||
}
|
||||
|
||||
/// Get the address from the host url
|
||||
/// Get the address from the host url.
|
||||
pub fn get_addresses_from_host(&self, host: &str) -> Vec<Address> {
|
||||
let mut addresses = Vec::new();
|
||||
for shard in 0..self.shards() {
|
||||
@@ -827,10 +828,13 @@ impl ConnectionPool {
|
||||
&self.addresses[shard][server]
|
||||
}
|
||||
|
||||
/// Get server settings retrieved at connection setup.
|
||||
pub fn server_info(&self) -> BytesMut {
|
||||
self.server_info.read().clone()
|
||||
}
|
||||
|
||||
/// Calculate how many used connections in the pool
|
||||
/// for the given server.
|
||||
fn busy_connection_count(&self, address: &Address) -> u32 {
|
||||
let state = self.pool_state(address.shard, address.address_index);
|
||||
let idle = state.idle_connections;
|
||||
|
||||
@@ -38,6 +38,7 @@ pub struct Server {
|
||||
|
||||
/// Our server response buffer. We buffer data before we give it to the client.
|
||||
buffer: BytesMut,
|
||||
is_async: bool,
|
||||
|
||||
/// Server information the server sent us over on startup.
|
||||
server_info: BytesMut,
|
||||
@@ -450,6 +451,7 @@ impl Server {
|
||||
read: BufReader::new(read),
|
||||
write,
|
||||
buffer: BytesMut::with_capacity(8196),
|
||||
is_async: false,
|
||||
server_info,
|
||||
process_id,
|
||||
secret_key,
|
||||
@@ -537,6 +539,16 @@ impl Server {
|
||||
}
|
||||
}
|
||||
|
||||
/// Switch to async mode, flushing messages as soon
|
||||
/// as we receive them without buffering or waiting for "ReadyForQuery".
|
||||
pub fn switch_async(&mut self, on: bool) {
|
||||
if on {
|
||||
self.is_async = true;
|
||||
} else {
|
||||
self.is_async = false;
|
||||
}
|
||||
}
|
||||
|
||||
/// Receive data from the server in response to a client request.
|
||||
/// This method must be called multiple times while `self.is_data_available()` is true
|
||||
/// in order to receive all data the server has to offer.
|
||||
@@ -632,7 +644,10 @@ impl Server {
|
||||
// DataRow
|
||||
'D' => {
|
||||
// More data is available after this message, this is not the end of the reply.
|
||||
self.data_available = true;
|
||||
// If we're async, flush to client now.
|
||||
if !self.is_async {
|
||||
self.data_available = true;
|
||||
}
|
||||
|
||||
// Don't flush yet, the more we buffer, the faster this goes...up to a limit.
|
||||
if self.buffer.len() >= 8196 {
|
||||
@@ -645,7 +660,10 @@ impl Server {
|
||||
|
||||
// CopyOutResponse: copy is starting from the server to the client.
|
||||
'H' => {
|
||||
self.data_available = true;
|
||||
// If we're in async mode, flush now.
|
||||
if !self.is_async {
|
||||
self.data_available = true;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -665,6 +683,10 @@ impl Server {
|
||||
// Keep buffering until ReadyForQuery shows up.
|
||||
_ => (),
|
||||
};
|
||||
|
||||
if self.is_async {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let bytes = self.buffer.clone();
|
||||
|
||||
60
tests/python/async_test.py
Normal file
60
tests/python/async_test.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import psycopg2
|
||||
import asyncio
|
||||
import asyncpg
|
||||
|
||||
PGCAT_HOST = "127.0.0.1"
|
||||
PGCAT_PORT = "6432"
|
||||
|
||||
|
||||
def regular_main():
|
||||
# Connect to the PostgreSQL database
|
||||
conn = psycopg2.connect(
|
||||
host=PGCAT_HOST,
|
||||
database="sharded_db",
|
||||
user="sharding_user",
|
||||
password="sharding_user",
|
||||
port=PGCAT_PORT,
|
||||
)
|
||||
|
||||
# Open a cursor to perform database operations
|
||||
cur = conn.cursor()
|
||||
|
||||
# Execute a SQL query
|
||||
cur.execute("SELECT 1")
|
||||
|
||||
# Fetch the results
|
||||
rows = cur.fetchall()
|
||||
|
||||
# Print the results
|
||||
for row in rows:
|
||||
print(row[0])
|
||||
|
||||
# Close the cursor and the database connection
|
||||
cur.close()
|
||||
conn.close()
|
||||
|
||||
|
||||
async def main():
|
||||
# Connect to the PostgreSQL database
|
||||
conn = await asyncpg.connect(
|
||||
host=PGCAT_HOST,
|
||||
database="sharded_db",
|
||||
user="sharding_user",
|
||||
password="sharding_user",
|
||||
port=PGCAT_PORT,
|
||||
)
|
||||
|
||||
# Execute a SQL query
|
||||
for _ in range(25):
|
||||
rows = await conn.fetch("SELECT 1")
|
||||
|
||||
# Print the results
|
||||
for row in rows:
|
||||
print(row[0])
|
||||
|
||||
# Close the database connection
|
||||
await conn.close()
|
||||
|
||||
|
||||
regular_main()
|
||||
asyncio.run(main())
|
||||
@@ -1,2 +1,11 @@
|
||||
asyncio==3.4.3
|
||||
asyncpg==0.27.0
|
||||
black==23.3.0
|
||||
click==8.1.3
|
||||
mypy-extensions==1.0.0
|
||||
packaging==23.1
|
||||
pathspec==0.11.1
|
||||
platformdirs==3.2.0
|
||||
psutil==5.9.1
|
||||
psycopg2==2.9.3
|
||||
psutil==5.9.1
|
||||
tomli==2.0.1
|
||||
|
||||
Reference in New Issue
Block a user