mirror of
https://github.com/postgresml/pgcat.git
synced 2026-03-23 01:16:30 +00:00
Compare commits
3 Commits
levkk-star
...
levkk-asyn
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fd3623ff13 | ||
|
|
088f1a7dae | ||
|
|
ab7ac16974 |
@@ -110,6 +110,10 @@ python3 tests/python/tests.py || exit 1
|
||||
|
||||
start_pgcat "info"
|
||||
|
||||
python3 tests/python/async_test.py
|
||||
|
||||
start_pgcat "info"
|
||||
|
||||
# Admin tests
|
||||
export PGPASSWORD=admin_pass
|
||||
psql -U admin_user -e -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW STATS' > /dev/null
|
||||
|
||||
@@ -49,14 +49,6 @@ default: 30000 # milliseconds
|
||||
|
||||
How long an idle connection with a server is left open (ms).
|
||||
|
||||
### server_lifetime
|
||||
```
|
||||
path: general.server_lifetime
|
||||
default: 86400000 # 24 hours
|
||||
```
|
||||
|
||||
Max connection lifetime before it's closed, even if actively used.
|
||||
|
||||
### idle_client_in_transaction_timeout
|
||||
```
|
||||
path: general.idle_client_in_transaction_timeout
|
||||
|
||||
32
Cargo.lock
generated
32
Cargo.lock
generated
@@ -762,11 +762,9 @@ dependencies = [
|
||||
"once_cell",
|
||||
"parking_lot",
|
||||
"phf",
|
||||
"pin-project",
|
||||
"postgres-protocol",
|
||||
"rand",
|
||||
"regex",
|
||||
"rustls",
|
||||
"rustls-pemfile",
|
||||
"serde",
|
||||
"serde_derive",
|
||||
@@ -778,7 +776,6 @@ dependencies = [
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
"toml",
|
||||
"webpki-roots",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -823,26 +820,6 @@ dependencies = [
|
||||
"siphasher",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pin-project"
|
||||
version = "1.0.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ad29a609b6bcd67fee905812e544992d216af9d755757c05ed2d0e15a74c6ecc"
|
||||
dependencies = [
|
||||
"pin-project-internal",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pin-project-internal"
|
||||
version = "1.0.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "069bdb1e05adc7a8990dce9cc75370895fbe4e3d58b9b73bf1aee56359344a55"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pin-project-lite"
|
||||
version = "0.2.9"
|
||||
@@ -1469,15 +1446,6 @@ dependencies = [
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "webpki-roots"
|
||||
version = "0.23.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "aa54963694b65584e170cf5dc46aeb4dcaa5584e652ff5f3952e56d66aff0125"
|
||||
dependencies = [
|
||||
"rustls-webpki",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "winapi"
|
||||
version = "0.3.9"
|
||||
|
||||
@@ -39,9 +39,6 @@ nix = "0.26.2"
|
||||
atomic_enum = "0.2.0"
|
||||
postgres-protocol = "0.6.5"
|
||||
fallible-iterator = "0.2"
|
||||
pin-project = "1"
|
||||
webpki-roots = "0.23"
|
||||
rustls = { version = "0.21", features = ["dangerous_configuration"] }
|
||||
|
||||
[target.'cfg(not(target_env = "msvc"))'.dependencies]
|
||||
jemallocator = "0.5.0"
|
||||
|
||||
15
pgcat.toml
15
pgcat.toml
@@ -23,9 +23,6 @@ connect_timeout = 5000 # milliseconds
|
||||
# How long an idle connection with a server is left open (ms).
|
||||
idle_timeout = 30000 # milliseconds
|
||||
|
||||
# Max connection lifetime before it's closed, even if actively used.
|
||||
server_lifetime = 86400000 # 24 hours
|
||||
|
||||
# How long a client is allowed to be idle while in a transaction (ms).
|
||||
idle_client_in_transaction_timeout = 0 # milliseconds
|
||||
|
||||
@@ -61,15 +58,9 @@ tcp_keepalives_count = 5
|
||||
tcp_keepalives_interval = 5
|
||||
|
||||
# Path to TLS Certificate file to use for TLS connections
|
||||
# tls_certificate = ".circleci/server.cert"
|
||||
# tls_certificate = "server.cert"
|
||||
# Path to TLS private key file to use for TLS connections
|
||||
# tls_private_key = ".circleci/server.key"
|
||||
|
||||
# Enable/disable server TLS
|
||||
server_tls = false
|
||||
|
||||
# Verify server certificate is completely authentic.
|
||||
verify_server_certificate = false
|
||||
# tls_private_key = "server.key"
|
||||
|
||||
# User name to access the virtual administrative database (pgbouncer or pgcat)
|
||||
# Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DATABASES`, etc..
|
||||
@@ -215,8 +206,6 @@ sharding_function = "pg_bigint_hash"
|
||||
username = "simple_user"
|
||||
password = "simple_user"
|
||||
pool_size = 5
|
||||
min_pool_size = 3
|
||||
server_lifetime = 60000
|
||||
statement_timeout = 0
|
||||
|
||||
[pools.simple_db.shards.0]
|
||||
|
||||
@@ -77,8 +77,6 @@ impl AuthPassthrough {
|
||||
pool_size: 1,
|
||||
statement_timeout: 0,
|
||||
pool_mode: None,
|
||||
server_lifetime: None,
|
||||
min_pool_size: None,
|
||||
};
|
||||
|
||||
let user = &address.username;
|
||||
|
||||
@@ -539,7 +539,6 @@ where
|
||||
Some(md5_hash_password(username, password, &salt))
|
||||
} else {
|
||||
if !get_config().is_auth_query_configured() {
|
||||
wrong_password(&mut write, username).await?;
|
||||
return Err(Error::ClientAuthImpossible(username.into()));
|
||||
}
|
||||
|
||||
@@ -566,8 +565,6 @@ where
|
||||
}
|
||||
|
||||
Err(err) => {
|
||||
wrong_password(&mut write, username).await?;
|
||||
|
||||
return Err(Error::ClientAuthPassthroughError(
|
||||
err.to_string(),
|
||||
client_identifier,
|
||||
@@ -590,15 +587,7 @@ where
|
||||
client_identifier
|
||||
);
|
||||
|
||||
let fetched_hash = match refetch_auth_hash(&pool).await {
|
||||
Ok(fetched_hash) => fetched_hash,
|
||||
Err(err) => {
|
||||
wrong_password(&mut write, username).await?;
|
||||
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
|
||||
let fetched_hash = refetch_auth_hash(&pool).await?;
|
||||
let new_password_hash = md5_hash_second_pass(&fetched_hash, &salt);
|
||||
|
||||
// Ok password changed in server an auth is possible.
|
||||
@@ -943,7 +932,7 @@ where
|
||||
}
|
||||
|
||||
// Grab a server from the pool.
|
||||
let connection = match pool
|
||||
let mut connection = match pool
|
||||
.get(query_router.shard(), query_router.role(), &self.stats)
|
||||
.await
|
||||
{
|
||||
@@ -986,9 +975,8 @@ where
|
||||
}
|
||||
};
|
||||
|
||||
let mut reference = connection.0;
|
||||
let server = &mut *connection.0;
|
||||
let address = connection.1;
|
||||
let server = &mut *reference;
|
||||
|
||||
// Server is assigned to the client in case the client wants to
|
||||
// cancel a query later.
|
||||
@@ -1011,6 +999,7 @@ where
|
||||
|
||||
// Set application_name.
|
||||
server.set_name(&self.application_name).await?;
|
||||
server.switch_async(false);
|
||||
|
||||
let mut initial_message = Some(message);
|
||||
|
||||
@@ -1030,12 +1019,37 @@ where
|
||||
None => {
|
||||
trace!("Waiting for message inside transaction or in session mode");
|
||||
|
||||
match tokio::time::timeout(
|
||||
idle_client_timeout_duration,
|
||||
read_message(&mut self.read),
|
||||
)
|
||||
.await
|
||||
{
|
||||
let message = tokio::select! {
|
||||
message = tokio::time::timeout(
|
||||
idle_client_timeout_duration,
|
||||
read_message(&mut self.read),
|
||||
) => message,
|
||||
|
||||
server_message = server.recv() => {
|
||||
debug!("Got async message");
|
||||
|
||||
let server_message = match server_message {
|
||||
Ok(message) => message,
|
||||
Err(err) => {
|
||||
pool.ban(&address, BanReason::MessageReceiveFailed, Some(&self.stats));
|
||||
server.mark_bad();
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
|
||||
match write_all_half(&mut self.write, &server_message).await {
|
||||
Ok(_) => (),
|
||||
Err(err) => {
|
||||
server.mark_bad();
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
match message {
|
||||
Ok(Ok(message)) => message,
|
||||
Ok(Err(err)) => {
|
||||
// Client disconnected inside a transaction.
|
||||
@@ -1152,9 +1166,14 @@ where
|
||||
|
||||
// Sync
|
||||
// Frontend (client) is asking for the query result now.
|
||||
'S' => {
|
||||
'S' | 'H' => {
|
||||
debug!("Sending query to server");
|
||||
|
||||
if code == 'H' {
|
||||
server.switch_async(true);
|
||||
debug!("Client requested flush, going async");
|
||||
}
|
||||
|
||||
self.buffer.put(&message[..]);
|
||||
|
||||
let first_message_code = (*self.buffer.get(0).unwrap_or(&0)) as char;
|
||||
|
||||
107
src/config.rs
107
src/config.rs
@@ -181,9 +181,7 @@ pub struct User {
|
||||
pub server_username: Option<String>,
|
||||
pub server_password: Option<String>,
|
||||
pub pool_size: u32,
|
||||
pub min_pool_size: Option<u32>,
|
||||
pub pool_mode: Option<PoolMode>,
|
||||
pub server_lifetime: Option<u64>,
|
||||
#[serde(default)] // 0
|
||||
pub statement_timeout: u64,
|
||||
}
|
||||
@@ -196,34 +194,12 @@ impl Default for User {
|
||||
server_username: None,
|
||||
server_password: None,
|
||||
pool_size: 15,
|
||||
min_pool_size: None,
|
||||
statement_timeout: 0,
|
||||
pool_mode: None,
|
||||
server_lifetime: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl User {
|
||||
fn validate(&self) -> Result<(), Error> {
|
||||
match self.min_pool_size {
|
||||
Some(min_pool_size) => {
|
||||
if min_pool_size > self.pool_size {
|
||||
error!(
|
||||
"min_pool_size of {} cannot be larger than pool_size of {}",
|
||||
min_pool_size, self.pool_size
|
||||
);
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
}
|
||||
|
||||
None => (),
|
||||
};
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// General configuration.
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
|
||||
pub struct General {
|
||||
@@ -270,9 +246,6 @@ pub struct General {
|
||||
#[serde(default = "General::default_idle_client_in_transaction_timeout")]
|
||||
pub idle_client_in_transaction_timeout: u64,
|
||||
|
||||
#[serde(default = "General::default_server_lifetime")]
|
||||
pub server_lifetime: u64,
|
||||
|
||||
#[serde(default = "General::default_worker_threads")]
|
||||
pub worker_threads: usize,
|
||||
|
||||
@@ -281,13 +254,6 @@ pub struct General {
|
||||
|
||||
pub tls_certificate: Option<String>,
|
||||
pub tls_private_key: Option<String>,
|
||||
|
||||
#[serde(default)] // false
|
||||
pub server_tls: bool,
|
||||
|
||||
#[serde(default)] // false
|
||||
pub verify_server_certificate: bool,
|
||||
|
||||
pub admin_username: String,
|
||||
pub admin_password: String,
|
||||
|
||||
@@ -305,10 +271,6 @@ impl General {
|
||||
5432
|
||||
}
|
||||
|
||||
pub fn default_server_lifetime() -> u64 {
|
||||
1000 * 60 * 60 * 24 // 24 hours
|
||||
}
|
||||
|
||||
pub fn default_connect_timeout() -> u64 {
|
||||
1000
|
||||
}
|
||||
@@ -380,14 +342,11 @@ impl Default for General {
|
||||
autoreload: None,
|
||||
tls_certificate: None,
|
||||
tls_private_key: None,
|
||||
server_tls: false,
|
||||
verify_server_certificate: false,
|
||||
admin_username: String::from("admin"),
|
||||
admin_password: String::from("admin"),
|
||||
auth_query: None,
|
||||
auth_query_user: None,
|
||||
auth_query_password: None,
|
||||
server_lifetime: 1000 * 3600 * 24, // 24 hours,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -452,8 +411,6 @@ pub struct Pool {
|
||||
|
||||
pub idle_timeout: Option<u64>,
|
||||
|
||||
pub server_lifetime: Option<u64>,
|
||||
|
||||
pub sharding_function: ShardingFunction,
|
||||
|
||||
#[serde(default = "Pool::default_automatic_sharding_key")]
|
||||
@@ -558,10 +515,6 @@ impl Pool {
|
||||
None => None,
|
||||
};
|
||||
|
||||
for (_, user) in &self.users {
|
||||
user.validate()?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -586,7 +539,6 @@ impl Default for Pool {
|
||||
auth_query: None,
|
||||
auth_query_user: None,
|
||||
auth_query_password: None,
|
||||
server_lifetime: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -636,7 +588,7 @@ impl Shard {
|
||||
|
||||
if primary_count > 1 {
|
||||
error!(
|
||||
"Shard {} has more than one primary configured",
|
||||
"Shard {} has more than on primary configured",
|
||||
self.database
|
||||
);
|
||||
return Err(Error::BadConfig);
|
||||
@@ -839,10 +791,6 @@ impl Config {
|
||||
);
|
||||
info!("Shutdown timeout: {}ms", self.general.shutdown_timeout);
|
||||
info!("Healthcheck delay: {}ms", self.general.healthcheck_delay);
|
||||
info!(
|
||||
"Default max server lifetime: {}ms",
|
||||
self.general.server_lifetime
|
||||
);
|
||||
match self.general.tls_certificate.clone() {
|
||||
Some(tls_certificate) => {
|
||||
info!("TLS certificate: {}", tls_certificate);
|
||||
@@ -861,11 +809,6 @@ impl Config {
|
||||
info!("TLS support is disabled");
|
||||
}
|
||||
};
|
||||
info!("Server TLS enabled: {}", self.general.server_tls);
|
||||
info!(
|
||||
"Server TLS certificate verification: {}",
|
||||
self.general.verify_server_certificate
|
||||
);
|
||||
|
||||
for (pool_name, pool_config) in &self.pools {
|
||||
// TODO: Make this output prettier (maybe a table?)
|
||||
@@ -924,26 +867,12 @@ impl Config {
|
||||
pool_name,
|
||||
pool_config.users.len()
|
||||
);
|
||||
info!(
|
||||
"[pool: {}] Max server lifetime: {}",
|
||||
pool_name,
|
||||
match pool_config.server_lifetime {
|
||||
Some(server_lifetime) => format!("{}ms", server_lifetime),
|
||||
None => "default".to_string(),
|
||||
}
|
||||
);
|
||||
|
||||
for user in &pool_config.users {
|
||||
info!(
|
||||
"[pool: {}][user: {}] Pool size: {}",
|
||||
pool_name, user.1.username, user.1.pool_size,
|
||||
);
|
||||
info!(
|
||||
"[pool: {}][user: {}] Minimum pool size: {}",
|
||||
pool_name,
|
||||
user.1.username,
|
||||
user.1.min_pool_size.unwrap_or(0)
|
||||
);
|
||||
info!(
|
||||
"[pool: {}][user: {}] Statement timeout: {}",
|
||||
pool_name, user.1.username, user.1.statement_timeout
|
||||
@@ -957,15 +886,6 @@ impl Config {
|
||||
None => pool_config.pool_mode.to_string(),
|
||||
}
|
||||
);
|
||||
info!(
|
||||
"[pool: {}][user: {}] Max server lifetime: {}",
|
||||
pool_name,
|
||||
user.1.username,
|
||||
match user.1.server_lifetime {
|
||||
Some(server_lifetime) => format!("{}ms", server_lifetime),
|
||||
None => "default".to_string(),
|
||||
}
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -976,13 +896,7 @@ impl Config {
|
||||
&& (self.general.auth_query_user.is_none()
|
||||
|| self.general.auth_query_password.is_none())
|
||||
{
|
||||
error!(
|
||||
"If auth_query is specified, \
|
||||
you need to provide a value \
|
||||
for `auth_query_user`, \
|
||||
`auth_query_password`"
|
||||
);
|
||||
|
||||
error!("If auth_query is specified, you need to provide a value for `auth_query_user`, `auth_query_password`");
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
|
||||
@@ -990,14 +904,7 @@ impl Config {
|
||||
if pool.auth_query.is_some()
|
||||
&& (pool.auth_query_user.is_none() || pool.auth_query_password.is_none())
|
||||
{
|
||||
error!(
|
||||
"Error in pool {{ {} }}. \
|
||||
If auth_query is specified, you need \
|
||||
to provide a value for `auth_query_user`, \
|
||||
`auth_query_password`",
|
||||
name
|
||||
);
|
||||
|
||||
error!("Error in pool {{ {} }}. If auth_query is specified, you need to provide a value for `auth_query_user`, `auth_query_password`", name);
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
|
||||
@@ -1007,13 +914,7 @@ impl Config {
|
||||
|| pool.auth_query_user.is_none())
|
||||
&& user_data.password.is_none()
|
||||
{
|
||||
error!(
|
||||
"Error in pool {{ {} }}. \
|
||||
You have to specify a user password \
|
||||
for every pool if auth_query is not specified",
|
||||
name
|
||||
);
|
||||
|
||||
error!("Error in pool {{ {} }}. You have to specify a user password for every pool if auth_query is not specified", name);
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -116,10 +116,7 @@ where
|
||||
|
||||
/// Send the startup packet the server. We're pretending we're a Pg client.
|
||||
/// This tells the server which user we are and what database we want.
|
||||
pub async fn startup<S>(stream: &mut S, user: &str, database: &str) -> Result<(), Error>
|
||||
where
|
||||
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||
{
|
||||
pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Result<(), Error> {
|
||||
let mut bytes = BytesMut::with_capacity(25);
|
||||
|
||||
bytes.put_i32(196608); // Protocol number
|
||||
@@ -153,21 +150,6 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn ssl_request(stream: &mut TcpStream) -> Result<(), Error> {
|
||||
let mut bytes = BytesMut::with_capacity(12);
|
||||
|
||||
bytes.put_i32(8);
|
||||
bytes.put_i32(80877103);
|
||||
|
||||
match stream.write_all(&bytes).await {
|
||||
Ok(_) => Ok(()),
|
||||
Err(err) => Err(Error::SocketError(format!(
|
||||
"Error writing SSLRequest to server socket - Error: {:?}",
|
||||
err
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse the params the server sends as a key/value format.
|
||||
pub fn parse_params(mut bytes: BytesMut) -> Result<HashMap<String, String>, Error> {
|
||||
let mut result = HashMap::new();
|
||||
@@ -523,29 +505,6 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn write_all_flush<S>(stream: &mut S, buf: &[u8]) -> Result<(), Error>
|
||||
where
|
||||
S: tokio::io::AsyncWrite + std::marker::Unpin,
|
||||
{
|
||||
match stream.write_all(buf).await {
|
||||
Ok(_) => match stream.flush().await {
|
||||
Ok(_) => Ok(()),
|
||||
Err(err) => {
|
||||
return Err(Error::SocketError(format!(
|
||||
"Error flushing socket - Error: {:?}",
|
||||
err
|
||||
)))
|
||||
}
|
||||
},
|
||||
Err(err) => {
|
||||
return Err(Error::SocketError(format!(
|
||||
"Error writing to socket - Error: {:?}",
|
||||
err
|
||||
)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Read a complete message from the socket.
|
||||
pub async fn read_message<S>(stream: &mut S) -> Result<BytesMut, Error>
|
||||
where
|
||||
|
||||
62
src/pool.rs
62
src/pool.rs
@@ -311,34 +311,21 @@ impl ConnectionPool {
|
||||
|
||||
if let Some(apt) = &auth_passthrough {
|
||||
match apt.fetch_hash(&address).await {
|
||||
Ok(ok) => {
|
||||
if let Some(ref pool_auth_hash_value) = *(pool_auth_hash.read())
|
||||
{
|
||||
if ok != *pool_auth_hash_value {
|
||||
warn!(
|
||||
"Hash is not the same across shards \
|
||||
of the same pool, client auth will \
|
||||
be done using last obtained hash. \
|
||||
Server: {}:{}, Database: {}",
|
||||
server.host, server.port, shard.database,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
debug!("Hash obtained for {:?}", address);
|
||||
|
||||
{
|
||||
let mut pool_auth_hash = pool_auth_hash.write();
|
||||
*pool_auth_hash = Some(ok.clone());
|
||||
}
|
||||
}
|
||||
Err(err) => warn!(
|
||||
"Could not obtain password hashes \
|
||||
using auth_query config, ignoring. \
|
||||
Error: {:?}",
|
||||
err,
|
||||
),
|
||||
}
|
||||
Ok(ok) => {
|
||||
if let Some(ref pool_auth_hash_value) = *(pool_auth_hash.read()) {
|
||||
if ok != *pool_auth_hash_value {
|
||||
warn!("Hash is not the same across shards of the same pool, client auth will \
|
||||
be done using last obtained hash. Server: {}:{}, Database: {}", server.host, server.port, shard.database);
|
||||
}
|
||||
}
|
||||
debug!("Hash obtained for {:?}", address);
|
||||
{
|
||||
let mut pool_auth_hash = pool_auth_hash.write();
|
||||
*pool_auth_hash = Some(ok.clone());
|
||||
}
|
||||
},
|
||||
Err(err) => warn!("Could not obtain password hashes using auth_query config, ignoring. Error: {:?}", err),
|
||||
}
|
||||
}
|
||||
|
||||
let manager = ServerPool::new(
|
||||
@@ -360,23 +347,14 @@ impl ConnectionPool {
|
||||
None => config.general.idle_timeout,
|
||||
};
|
||||
|
||||
let server_lifetime = match user.server_lifetime {
|
||||
Some(server_lifetime) => server_lifetime,
|
||||
None => match pool_config.server_lifetime {
|
||||
Some(server_lifetime) => server_lifetime,
|
||||
None => config.general.server_lifetime,
|
||||
},
|
||||
};
|
||||
|
||||
let pool = Pool::builder()
|
||||
.max_size(user.pool_size)
|
||||
.min_idle(user.min_pool_size)
|
||||
.connection_timeout(std::time::Duration::from_millis(connect_timeout))
|
||||
.idle_timeout(Some(std::time::Duration::from_millis(idle_timeout)))
|
||||
.max_lifetime(Some(std::time::Duration::from_millis(server_lifetime)))
|
||||
.test_on_check_out(false)
|
||||
.build(manager)
|
||||
.await?;
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
pools.push(pool);
|
||||
servers.push(address);
|
||||
@@ -799,6 +777,7 @@ impl ConnectionPool {
|
||||
self.databases.len()
|
||||
}
|
||||
|
||||
/// Retrieve all bans for all servers.
|
||||
pub fn get_bans(&self) -> Vec<(Address, (BanReason, NaiveDateTime))> {
|
||||
let mut bans: Vec<(Address, (BanReason, NaiveDateTime))> = Vec::new();
|
||||
let guard = self.banlist.read();
|
||||
@@ -810,7 +789,7 @@ impl ConnectionPool {
|
||||
return bans;
|
||||
}
|
||||
|
||||
/// Get the address from the host url
|
||||
/// Get the address from the host url.
|
||||
pub fn get_addresses_from_host(&self, host: &str) -> Vec<Address> {
|
||||
let mut addresses = Vec::new();
|
||||
for shard in 0..self.shards() {
|
||||
@@ -849,10 +828,13 @@ impl ConnectionPool {
|
||||
&self.addresses[shard][server]
|
||||
}
|
||||
|
||||
/// Get server settings retrieved at connection setup.
|
||||
pub fn server_info(&self) -> BytesMut {
|
||||
self.server_info.read().clone()
|
||||
}
|
||||
|
||||
/// Calculate how many used connections in the pool
|
||||
/// for the given server.
|
||||
fn busy_connection_count(&self, address: &Address) -> u32 {
|
||||
let state = self.pool_state(address.shard, address.address_index);
|
||||
let idle = state.idle_connections;
|
||||
|
||||
227
src/server.rs
227
src/server.rs
@@ -9,12 +9,13 @@ use std::collections::HashMap;
|
||||
use std::io::Read;
|
||||
use std::sync::Arc;
|
||||
use std::time::SystemTime;
|
||||
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, BufStream};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_rustls::rustls::{OwnedTrustAnchor, RootCertStore};
|
||||
use tokio_rustls::{client::TlsStream, TlsConnector};
|
||||
use tokio::io::{AsyncReadExt, BufReader};
|
||||
use tokio::net::{
|
||||
tcp::{OwnedReadHalf, OwnedWriteHalf},
|
||||
TcpStream,
|
||||
};
|
||||
|
||||
use crate::config::{get_config, Address, User};
|
||||
use crate::config::{Address, User};
|
||||
use crate::constants::*;
|
||||
use crate::errors::{Error, ServerIdentifier};
|
||||
use crate::messages::*;
|
||||
@@ -22,84 +23,6 @@ use crate::mirrors::MirroringManager;
|
||||
use crate::pool::ClientServerMap;
|
||||
use crate::scram::ScramSha256;
|
||||
use crate::stats::ServerStats;
|
||||
use std::io::Write;
|
||||
|
||||
use pin_project::pin_project;
|
||||
|
||||
#[pin_project(project = SteamInnerProj)]
|
||||
pub enum StreamInner {
|
||||
Plain {
|
||||
#[pin]
|
||||
stream: TcpStream,
|
||||
},
|
||||
Tls {
|
||||
#[pin]
|
||||
stream: TlsStream<TcpStream>,
|
||||
},
|
||||
}
|
||||
|
||||
impl AsyncWrite for StreamInner {
|
||||
fn poll_write(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> std::task::Poll<Result<usize, std::io::Error>> {
|
||||
let this = self.project();
|
||||
match this {
|
||||
SteamInnerProj::Tls { stream } => stream.poll_write(cx, buf),
|
||||
SteamInnerProj::Plain { stream } => stream.poll_write(cx, buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), std::io::Error>> {
|
||||
let this = self.project();
|
||||
match this {
|
||||
SteamInnerProj::Tls { stream } => stream.poll_flush(cx),
|
||||
SteamInnerProj::Plain { stream } => stream.poll_flush(cx),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_shutdown(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
) -> std::task::Poll<Result<(), std::io::Error>> {
|
||||
let this = self.project();
|
||||
match this {
|
||||
SteamInnerProj::Tls { stream } => stream.poll_shutdown(cx),
|
||||
SteamInnerProj::Plain { stream } => stream.poll_shutdown(cx),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for StreamInner {
|
||||
fn poll_read(
|
||||
self: std::pin::Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &mut tokio::io::ReadBuf<'_>,
|
||||
) -> std::task::Poll<std::io::Result<()>> {
|
||||
let this = self.project();
|
||||
match this {
|
||||
SteamInnerProj::Tls { stream } => stream.poll_read(cx, buf),
|
||||
SteamInnerProj::Plain { stream } => stream.poll_read(cx, buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamInner {
|
||||
pub fn try_write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
|
||||
match self {
|
||||
StreamInner::Tls { stream } => {
|
||||
let r = stream.get_mut();
|
||||
let mut w = r.1.writer();
|
||||
w.write(buf)
|
||||
}
|
||||
StreamInner::Plain { stream } => stream.try_write(buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Server state.
|
||||
pub struct Server {
|
||||
@@ -107,11 +30,15 @@ pub struct Server {
|
||||
/// port, e.g. 5432, and role, e.g. primary or replica.
|
||||
address: Address,
|
||||
|
||||
/// Server TCP connection.
|
||||
stream: BufStream<StreamInner>,
|
||||
/// Buffered read socket.
|
||||
read: BufReader<OwnedReadHalf>,
|
||||
|
||||
/// Unbuffered write socket (our client code buffers).
|
||||
write: OwnedWriteHalf,
|
||||
|
||||
/// Our server response buffer. We buffer data before we give it to the client.
|
||||
buffer: BytesMut,
|
||||
is_async: bool,
|
||||
|
||||
/// Server information the server sent us over on startup.
|
||||
server_info: BytesMut,
|
||||
@@ -172,88 +99,8 @@ impl Server {
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
// TCP timeouts.
|
||||
configure_socket(&stream);
|
||||
|
||||
let config = get_config();
|
||||
|
||||
let mut stream = if config.general.server_tls {
|
||||
// Request a TLS connection
|
||||
ssl_request(&mut stream).await?;
|
||||
|
||||
let response = match stream.read_u8().await {
|
||||
Ok(response) => response as char,
|
||||
Err(err) => {
|
||||
return Err(Error::SocketError(format!(
|
||||
"Server socket error: {:?}",
|
||||
err
|
||||
)))
|
||||
}
|
||||
};
|
||||
|
||||
match response {
|
||||
// Server supports TLS
|
||||
'S' => {
|
||||
debug!("Connecting to server using TLS");
|
||||
|
||||
let mut root_store = RootCertStore::empty();
|
||||
root_store.add_server_trust_anchors(
|
||||
webpki_roots::TLS_SERVER_ROOTS.0.iter().map(|ta| {
|
||||
OwnedTrustAnchor::from_subject_spki_name_constraints(
|
||||
ta.subject,
|
||||
ta.spki,
|
||||
ta.name_constraints,
|
||||
)
|
||||
}),
|
||||
);
|
||||
|
||||
let mut tls_config = rustls::ClientConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_root_certificates(root_store)
|
||||
.with_no_client_auth();
|
||||
|
||||
// Equivalent to sslmode=prefer which is fine most places.
|
||||
// If you want verify-full, change `verify_server_certificate` to true.
|
||||
if !config.general.verify_server_certificate {
|
||||
let mut dangerous = tls_config.dangerous();
|
||||
dangerous.set_certificate_verifier(Arc::new(
|
||||
crate::tls::NoCertificateVerification {},
|
||||
));
|
||||
}
|
||||
|
||||
let connector = TlsConnector::from(Arc::new(tls_config));
|
||||
let stream = match connector
|
||||
.connect(address.host.as_str().try_into().unwrap(), stream)
|
||||
.await
|
||||
{
|
||||
Ok(stream) => stream,
|
||||
Err(err) => {
|
||||
return Err(Error::SocketError(format!("Server TLS error: {:?}", err)))
|
||||
}
|
||||
};
|
||||
|
||||
StreamInner::Tls { stream }
|
||||
}
|
||||
|
||||
// Server does not support TLS
|
||||
'N' => StreamInner::Plain { stream },
|
||||
|
||||
// Something else?
|
||||
m => {
|
||||
return Err(Error::SocketError(format!(
|
||||
"Unknown message: {}",
|
||||
m as char
|
||||
)));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
StreamInner::Plain { stream }
|
||||
};
|
||||
|
||||
// let (read, write) = split(stream);
|
||||
// let (mut read, mut write) = (ReadInner::Plain { stream: read }, WriteInner::Plain { stream: write });
|
||||
|
||||
trace!("Sending StartupMessage");
|
||||
|
||||
// StartupMessage
|
||||
@@ -399,7 +246,7 @@ impl Server {
|
||||
|
||||
let sasl_type = String::from_utf8_lossy(&sasl_auth[..sasl_len - 2]);
|
||||
|
||||
if sasl_type.contains(SCRAM_SHA_256) {
|
||||
if sasl_type == SCRAM_SHA_256 {
|
||||
debug!("Using {}", SCRAM_SHA_256);
|
||||
|
||||
// Generate client message.
|
||||
@@ -422,7 +269,7 @@ impl Server {
|
||||
res.put_i32(sasl_response.len() as i32);
|
||||
res.put(sasl_response);
|
||||
|
||||
write_all_flush(&mut stream, &res).await?;
|
||||
write_all(&mut stream, res).await?;
|
||||
} else {
|
||||
error!("Unsupported SCRAM version: {}", sasl_type);
|
||||
return Err(Error::ServerError);
|
||||
@@ -453,7 +300,7 @@ impl Server {
|
||||
res.put_i32(4 + sasl_response.len() as i32);
|
||||
res.put(sasl_response);
|
||||
|
||||
write_all_flush(&mut stream, &res).await?;
|
||||
write_all(&mut stream, res).await?;
|
||||
}
|
||||
|
||||
SASL_FINAL => {
|
||||
@@ -597,10 +444,14 @@ impl Server {
|
||||
}
|
||||
};
|
||||
|
||||
let (read, write) = stream.into_split();
|
||||
|
||||
let mut server = Server {
|
||||
address: address.clone(),
|
||||
stream: BufStream::new(stream),
|
||||
read: BufReader::new(read),
|
||||
write,
|
||||
buffer: BytesMut::with_capacity(8196),
|
||||
is_async: false,
|
||||
server_info,
|
||||
process_id,
|
||||
secret_key,
|
||||
@@ -666,7 +517,7 @@ impl Server {
|
||||
bytes.put_i32(process_id);
|
||||
bytes.put_i32(secret_key);
|
||||
|
||||
write_all_flush(&mut stream, &bytes).await
|
||||
write_all(&mut stream, bytes).await
|
||||
}
|
||||
|
||||
/// Send messages to the server from the client.
|
||||
@@ -674,7 +525,7 @@ impl Server {
|
||||
self.mirror_send(messages);
|
||||
self.stats().data_sent(messages.len());
|
||||
|
||||
match write_all_flush(&mut self.stream, &messages).await {
|
||||
match write_all_half(&mut self.write, messages).await {
|
||||
Ok(_) => {
|
||||
// Successfully sent to server
|
||||
self.last_activity = SystemTime::now();
|
||||
@@ -688,12 +539,22 @@ impl Server {
|
||||
}
|
||||
}
|
||||
|
||||
/// Switch to async mode, flushing messages as soon
|
||||
/// as we receive them without buffering or waiting for "ReadyForQuery".
|
||||
pub fn switch_async(&mut self, on: bool) {
|
||||
if on {
|
||||
self.is_async = true;
|
||||
} else {
|
||||
self.is_async = false;
|
||||
}
|
||||
}
|
||||
|
||||
/// Receive data from the server in response to a client request.
|
||||
/// This method must be called multiple times while `self.is_data_available()` is true
|
||||
/// in order to receive all data the server has to offer.
|
||||
pub async fn recv(&mut self) -> Result<BytesMut, Error> {
|
||||
loop {
|
||||
let mut message = match read_message(&mut self.stream).await {
|
||||
let mut message = match read_message(&mut self.read).await {
|
||||
Ok(message) => message,
|
||||
Err(err) => {
|
||||
error!("Terminating server because of: {:?}", err);
|
||||
@@ -783,7 +644,10 @@ impl Server {
|
||||
// DataRow
|
||||
'D' => {
|
||||
// More data is available after this message, this is not the end of the reply.
|
||||
self.data_available = true;
|
||||
// If we're async, flush to client now.
|
||||
if !self.is_async {
|
||||
self.data_available = true;
|
||||
}
|
||||
|
||||
// Don't flush yet, the more we buffer, the faster this goes...up to a limit.
|
||||
if self.buffer.len() >= 8196 {
|
||||
@@ -796,7 +660,10 @@ impl Server {
|
||||
|
||||
// CopyOutResponse: copy is starting from the server to the client.
|
||||
'H' => {
|
||||
self.data_available = true;
|
||||
// If we're in async mode, flush now.
|
||||
if !self.is_async {
|
||||
self.data_available = true;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -816,6 +683,10 @@ impl Server {
|
||||
// Keep buffering until ReadyForQuery shows up.
|
||||
_ => (),
|
||||
};
|
||||
|
||||
if self.is_async {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let bytes = self.buffer.clone();
|
||||
@@ -1086,13 +957,13 @@ impl Drop for Server {
|
||||
// Update statistics
|
||||
self.stats.disconnect();
|
||||
|
||||
let mut bytes = BytesMut::with_capacity(5);
|
||||
let mut bytes = BytesMut::with_capacity(4);
|
||||
bytes.put_u8(b'X');
|
||||
bytes.put_i32(4);
|
||||
|
||||
match self.stream.get_mut().try_write(&bytes) {
|
||||
Ok(5) => (),
|
||||
_ => debug!("Dirty shutdown"),
|
||||
match self.write.try_write(&bytes) {
|
||||
Ok(_) => (),
|
||||
Err(_) => debug!("Dirty shutdown"),
|
||||
};
|
||||
|
||||
// Should not matter.
|
||||
|
||||
23
src/tls.rs
23
src/tls.rs
@@ -4,12 +4,7 @@ use rustls_pemfile::{certs, read_one, Item};
|
||||
use std::iter;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use std::time::SystemTime;
|
||||
use tokio_rustls::rustls::{
|
||||
self,
|
||||
client::{ServerCertVerified, ServerCertVerifier},
|
||||
Certificate, PrivateKey, ServerName,
|
||||
};
|
||||
use tokio_rustls::rustls::{self, Certificate, PrivateKey};
|
||||
use tokio_rustls::TlsAcceptor;
|
||||
|
||||
use crate::config::get_config;
|
||||
@@ -69,19 +64,3 @@ impl Tls {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct NoCertificateVerification;
|
||||
|
||||
impl ServerCertVerifier for NoCertificateVerification {
|
||||
fn verify_server_cert(
|
||||
&self,
|
||||
_end_entity: &Certificate,
|
||||
_intermediates: &[Certificate],
|
||||
_server_name: &ServerName,
|
||||
_scts: &mut dyn Iterator<Item = &[u8]>,
|
||||
_ocsp_response: &[u8],
|
||||
_now: SystemTime,
|
||||
) -> Result<ServerCertVerified, rustls::Error> {
|
||||
Ok(ServerCertVerified::assertion())
|
||||
}
|
||||
}
|
||||
|
||||
60
tests/python/async_test.py
Normal file
60
tests/python/async_test.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import psycopg2
|
||||
import asyncio
|
||||
import asyncpg
|
||||
|
||||
PGCAT_HOST = "127.0.0.1"
|
||||
PGCAT_PORT = "6432"
|
||||
|
||||
|
||||
def regular_main():
|
||||
# Connect to the PostgreSQL database
|
||||
conn = psycopg2.connect(
|
||||
host=PGCAT_HOST,
|
||||
database="sharded_db",
|
||||
user="sharding_user",
|
||||
password="sharding_user",
|
||||
port=PGCAT_PORT,
|
||||
)
|
||||
|
||||
# Open a cursor to perform database operations
|
||||
cur = conn.cursor()
|
||||
|
||||
# Execute a SQL query
|
||||
cur.execute("SELECT 1")
|
||||
|
||||
# Fetch the results
|
||||
rows = cur.fetchall()
|
||||
|
||||
# Print the results
|
||||
for row in rows:
|
||||
print(row[0])
|
||||
|
||||
# Close the cursor and the database connection
|
||||
cur.close()
|
||||
conn.close()
|
||||
|
||||
|
||||
async def main():
|
||||
# Connect to the PostgreSQL database
|
||||
conn = await asyncpg.connect(
|
||||
host=PGCAT_HOST,
|
||||
database="sharded_db",
|
||||
user="sharding_user",
|
||||
password="sharding_user",
|
||||
port=PGCAT_PORT,
|
||||
)
|
||||
|
||||
# Execute a SQL query
|
||||
for _ in range(25):
|
||||
rows = await conn.fetch("SELECT 1")
|
||||
|
||||
# Print the results
|
||||
for row in rows:
|
||||
print(row[0])
|
||||
|
||||
# Close the database connection
|
||||
await conn.close()
|
||||
|
||||
|
||||
regular_main()
|
||||
asyncio.run(main())
|
||||
@@ -1,2 +1,11 @@
|
||||
asyncio==3.4.3
|
||||
asyncpg==0.27.0
|
||||
black==23.3.0
|
||||
click==8.1.3
|
||||
mypy-extensions==1.0.0
|
||||
packaging==23.1
|
||||
pathspec==0.11.1
|
||||
platformdirs==3.2.0
|
||||
psutil==5.9.1
|
||||
psycopg2==2.9.3
|
||||
psutil==5.9.1
|
||||
tomli==2.0.1
|
||||
|
||||
@@ -25,7 +25,7 @@ describe "Query Mirroing" do
|
||||
processes.pgcat.shutdown
|
||||
end
|
||||
|
||||
xit "can mirror a query" do
|
||||
it "can mirror a query" do
|
||||
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
|
||||
runs = 15
|
||||
runs.times { conn.async_exec("SELECT 1 + 2") }
|
||||
|
||||
Reference in New Issue
Block a user