mirror of
https://github.com/postgresml/pgcat.git
synced 2026-03-23 09:26:30 +00:00
Compare commits
3 Commits
mostafa_ad
...
levkk-asyn
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fd3623ff13 | ||
|
|
088f1a7dae | ||
|
|
ab7ac16974 |
@@ -110,6 +110,10 @@ python3 tests/python/tests.py || exit 1
|
|||||||
|
|
||||||
start_pgcat "info"
|
start_pgcat "info"
|
||||||
|
|
||||||
|
python3 tests/python/async_test.py
|
||||||
|
|
||||||
|
start_pgcat "info"
|
||||||
|
|
||||||
# Admin tests
|
# Admin tests
|
||||||
export PGPASSWORD=admin_pass
|
export PGPASSWORD=admin_pass
|
||||||
psql -U admin_user -e -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW STATS' > /dev/null
|
psql -U admin_user -e -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW STATS' > /dev/null
|
||||||
|
|||||||
@@ -932,7 +932,7 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Grab a server from the pool.
|
// Grab a server from the pool.
|
||||||
let connection = match pool
|
let mut connection = match pool
|
||||||
.get(query_router.shard(), query_router.role(), &self.stats)
|
.get(query_router.shard(), query_router.role(), &self.stats)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
@@ -975,9 +975,8 @@ where
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut reference = connection.0;
|
let server = &mut *connection.0;
|
||||||
let address = connection.1;
|
let address = connection.1;
|
||||||
let server = &mut *reference;
|
|
||||||
|
|
||||||
// Server is assigned to the client in case the client wants to
|
// Server is assigned to the client in case the client wants to
|
||||||
// cancel a query later.
|
// cancel a query later.
|
||||||
@@ -1000,6 +999,7 @@ where
|
|||||||
|
|
||||||
// Set application_name.
|
// Set application_name.
|
||||||
server.set_name(&self.application_name).await?;
|
server.set_name(&self.application_name).await?;
|
||||||
|
server.switch_async(false);
|
||||||
|
|
||||||
let mut initial_message = Some(message);
|
let mut initial_message = Some(message);
|
||||||
|
|
||||||
@@ -1019,12 +1019,37 @@ where
|
|||||||
None => {
|
None => {
|
||||||
trace!("Waiting for message inside transaction or in session mode");
|
trace!("Waiting for message inside transaction or in session mode");
|
||||||
|
|
||||||
match tokio::time::timeout(
|
let message = tokio::select! {
|
||||||
idle_client_timeout_duration,
|
message = tokio::time::timeout(
|
||||||
read_message(&mut self.read),
|
idle_client_timeout_duration,
|
||||||
)
|
read_message(&mut self.read),
|
||||||
.await
|
) => message,
|
||||||
{
|
|
||||||
|
server_message = server.recv() => {
|
||||||
|
debug!("Got async message");
|
||||||
|
|
||||||
|
let server_message = match server_message {
|
||||||
|
Ok(message) => message,
|
||||||
|
Err(err) => {
|
||||||
|
pool.ban(&address, BanReason::MessageReceiveFailed, Some(&self.stats));
|
||||||
|
server.mark_bad();
|
||||||
|
return Err(err);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
match write_all_half(&mut self.write, &server_message).await {
|
||||||
|
Ok(_) => (),
|
||||||
|
Err(err) => {
|
||||||
|
server.mark_bad();
|
||||||
|
return Err(err);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
match message {
|
||||||
Ok(Ok(message)) => message,
|
Ok(Ok(message)) => message,
|
||||||
Ok(Err(err)) => {
|
Ok(Err(err)) => {
|
||||||
// Client disconnected inside a transaction.
|
// Client disconnected inside a transaction.
|
||||||
@@ -1141,9 +1166,14 @@ where
|
|||||||
|
|
||||||
// Sync
|
// Sync
|
||||||
// Frontend (client) is asking for the query result now.
|
// Frontend (client) is asking for the query result now.
|
||||||
'S' => {
|
'S' | 'H' => {
|
||||||
debug!("Sending query to server");
|
debug!("Sending query to server");
|
||||||
|
|
||||||
|
if code == 'H' {
|
||||||
|
server.switch_async(true);
|
||||||
|
debug!("Client requested flush, going async");
|
||||||
|
}
|
||||||
|
|
||||||
self.buffer.put(&message[..]);
|
self.buffer.put(&message[..]);
|
||||||
|
|
||||||
let first_message_code = (*self.buffer.get(0).unwrap_or(&0)) as char;
|
let first_message_code = (*self.buffer.get(0).unwrap_or(&0)) as char;
|
||||||
|
|||||||
@@ -777,6 +777,7 @@ impl ConnectionPool {
|
|||||||
self.databases.len()
|
self.databases.len()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Retrieve all bans for all servers.
|
||||||
pub fn get_bans(&self) -> Vec<(Address, (BanReason, NaiveDateTime))> {
|
pub fn get_bans(&self) -> Vec<(Address, (BanReason, NaiveDateTime))> {
|
||||||
let mut bans: Vec<(Address, (BanReason, NaiveDateTime))> = Vec::new();
|
let mut bans: Vec<(Address, (BanReason, NaiveDateTime))> = Vec::new();
|
||||||
let guard = self.banlist.read();
|
let guard = self.banlist.read();
|
||||||
@@ -788,7 +789,7 @@ impl ConnectionPool {
|
|||||||
return bans;
|
return bans;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the address from the host url
|
/// Get the address from the host url.
|
||||||
pub fn get_addresses_from_host(&self, host: &str) -> Vec<Address> {
|
pub fn get_addresses_from_host(&self, host: &str) -> Vec<Address> {
|
||||||
let mut addresses = Vec::new();
|
let mut addresses = Vec::new();
|
||||||
for shard in 0..self.shards() {
|
for shard in 0..self.shards() {
|
||||||
@@ -827,10 +828,13 @@ impl ConnectionPool {
|
|||||||
&self.addresses[shard][server]
|
&self.addresses[shard][server]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Get server settings retrieved at connection setup.
|
||||||
pub fn server_info(&self) -> BytesMut {
|
pub fn server_info(&self) -> BytesMut {
|
||||||
self.server_info.read().clone()
|
self.server_info.read().clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Calculate how many used connections in the pool
|
||||||
|
/// for the given server.
|
||||||
fn busy_connection_count(&self, address: &Address) -> u32 {
|
fn busy_connection_count(&self, address: &Address) -> u32 {
|
||||||
let state = self.pool_state(address.shard, address.address_index);
|
let state = self.pool_state(address.shard, address.address_index);
|
||||||
let idle = state.idle_connections;
|
let idle = state.idle_connections;
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ pub struct Server {
|
|||||||
|
|
||||||
/// Our server response buffer. We buffer data before we give it to the client.
|
/// Our server response buffer. We buffer data before we give it to the client.
|
||||||
buffer: BytesMut,
|
buffer: BytesMut,
|
||||||
|
is_async: bool,
|
||||||
|
|
||||||
/// Server information the server sent us over on startup.
|
/// Server information the server sent us over on startup.
|
||||||
server_info: BytesMut,
|
server_info: BytesMut,
|
||||||
@@ -450,6 +451,7 @@ impl Server {
|
|||||||
read: BufReader::new(read),
|
read: BufReader::new(read),
|
||||||
write,
|
write,
|
||||||
buffer: BytesMut::with_capacity(8196),
|
buffer: BytesMut::with_capacity(8196),
|
||||||
|
is_async: false,
|
||||||
server_info,
|
server_info,
|
||||||
process_id,
|
process_id,
|
||||||
secret_key,
|
secret_key,
|
||||||
@@ -537,6 +539,16 @@ impl Server {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Switch to async mode, flushing messages as soon
|
||||||
|
/// as we receive them without buffering or waiting for "ReadyForQuery".
|
||||||
|
pub fn switch_async(&mut self, on: bool) {
|
||||||
|
if on {
|
||||||
|
self.is_async = true;
|
||||||
|
} else {
|
||||||
|
self.is_async = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Receive data from the server in response to a client request.
|
/// Receive data from the server in response to a client request.
|
||||||
/// This method must be called multiple times while `self.is_data_available()` is true
|
/// This method must be called multiple times while `self.is_data_available()` is true
|
||||||
/// in order to receive all data the server has to offer.
|
/// in order to receive all data the server has to offer.
|
||||||
@@ -632,7 +644,10 @@ impl Server {
|
|||||||
// DataRow
|
// DataRow
|
||||||
'D' => {
|
'D' => {
|
||||||
// More data is available after this message, this is not the end of the reply.
|
// More data is available after this message, this is not the end of the reply.
|
||||||
self.data_available = true;
|
// If we're async, flush to client now.
|
||||||
|
if !self.is_async {
|
||||||
|
self.data_available = true;
|
||||||
|
}
|
||||||
|
|
||||||
// Don't flush yet, the more we buffer, the faster this goes...up to a limit.
|
// Don't flush yet, the more we buffer, the faster this goes...up to a limit.
|
||||||
if self.buffer.len() >= 8196 {
|
if self.buffer.len() >= 8196 {
|
||||||
@@ -645,7 +660,10 @@ impl Server {
|
|||||||
|
|
||||||
// CopyOutResponse: copy is starting from the server to the client.
|
// CopyOutResponse: copy is starting from the server to the client.
|
||||||
'H' => {
|
'H' => {
|
||||||
self.data_available = true;
|
// If we're in async mode, flush now.
|
||||||
|
if !self.is_async {
|
||||||
|
self.data_available = true;
|
||||||
|
}
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -665,6 +683,10 @@ impl Server {
|
|||||||
// Keep buffering until ReadyForQuery shows up.
|
// Keep buffering until ReadyForQuery shows up.
|
||||||
_ => (),
|
_ => (),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
if self.is_async {
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let bytes = self.buffer.clone();
|
let bytes = self.buffer.clone();
|
||||||
|
|||||||
60
tests/python/async_test.py
Normal file
60
tests/python/async_test.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
import psycopg2
|
||||||
|
import asyncio
|
||||||
|
import asyncpg
|
||||||
|
|
||||||
|
PGCAT_HOST = "127.0.0.1"
|
||||||
|
PGCAT_PORT = "6432"
|
||||||
|
|
||||||
|
|
||||||
|
def regular_main():
|
||||||
|
# Connect to the PostgreSQL database
|
||||||
|
conn = psycopg2.connect(
|
||||||
|
host=PGCAT_HOST,
|
||||||
|
database="sharded_db",
|
||||||
|
user="sharding_user",
|
||||||
|
password="sharding_user",
|
||||||
|
port=PGCAT_PORT,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Open a cursor to perform database operations
|
||||||
|
cur = conn.cursor()
|
||||||
|
|
||||||
|
# Execute a SQL query
|
||||||
|
cur.execute("SELECT 1")
|
||||||
|
|
||||||
|
# Fetch the results
|
||||||
|
rows = cur.fetchall()
|
||||||
|
|
||||||
|
# Print the results
|
||||||
|
for row in rows:
|
||||||
|
print(row[0])
|
||||||
|
|
||||||
|
# Close the cursor and the database connection
|
||||||
|
cur.close()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
# Connect to the PostgreSQL database
|
||||||
|
conn = await asyncpg.connect(
|
||||||
|
host=PGCAT_HOST,
|
||||||
|
database="sharded_db",
|
||||||
|
user="sharding_user",
|
||||||
|
password="sharding_user",
|
||||||
|
port=PGCAT_PORT,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Execute a SQL query
|
||||||
|
for _ in range(25):
|
||||||
|
rows = await conn.fetch("SELECT 1")
|
||||||
|
|
||||||
|
# Print the results
|
||||||
|
for row in rows:
|
||||||
|
print(row[0])
|
||||||
|
|
||||||
|
# Close the database connection
|
||||||
|
await conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
regular_main()
|
||||||
|
asyncio.run(main())
|
||||||
@@ -1,2 +1,11 @@
|
|||||||
psycopg2==2.9.3
|
asyncio==3.4.3
|
||||||
|
asyncpg==0.27.0
|
||||||
|
black==23.3.0
|
||||||
|
click==8.1.3
|
||||||
|
mypy-extensions==1.0.0
|
||||||
|
packaging==23.1
|
||||||
|
pathspec==0.11.1
|
||||||
|
platformdirs==3.2.0
|
||||||
psutil==5.9.1
|
psutil==5.9.1
|
||||||
|
psycopg2==2.9.3
|
||||||
|
tomli==2.0.1
|
||||||
|
|||||||
Reference in New Issue
Block a user