Compare commits

...

28 Commits

Author SHA1 Message Date
Kevin Zimmerman
dc649aaee3 simplified write!; generic new functions 2023-07-26 19:53:32 -05:00
Kevin Zimmerman
b4ba3b378c fix formatting 2023-07-26 09:47:58 -05:00
Kevin Zimmerman
81536a0bad make AuthPassthrough generic 2023-07-26 09:44:31 -05:00
Kevin Zimmerman
6eb01e51a0 remove async/spawn in Collector::collect 2023-07-25 19:56:15 -05:00
Kevin Zimmerman
ae3241b634 use Result::map_err and ? in Tls::new 2023-07-25 19:44:04 -05:00
Kevin Zimmerman
33724ea670 simplify TableAccess::run 2023-07-25 19:34:37 -05:00
Kevin Zimmerman
1c26aa3547 simplify format! 2023-07-25 19:24:52 -05:00
Kevin Zimmerman
64eb417125 remove unnecessary allocation 2023-07-25 19:24:04 -05:00
Kevin Zimmerman
22d9d3c90a fix query_logger info! argument order 2023-07-25 15:59:01 -05:00
Kevin Zimmerman
3162d550fd simplify format_duration, reduce String allocs 2023-07-25 15:54:48 -05:00
Kevin Zimmerman
12522562ce fix clippy lints 2023-07-25 15:49:46 -05:00
Lev Kokotov
4cf54a6122 Release 1.1 (#526) 2023-07-25 10:27:04 -07:00
Mostafa Abdelraouf
2a8f3653a6 Fix COPY FROM and add tests (#522)
* Fix COPY FROM and add tests

* E

* fmt
2023-07-20 23:06:01 -07:00
Sebastian Webber
19cb8a3022 add --no-color option to disable colors in the terminal (#518)
add --no-color option to disable colors

this commit adds a new option to disable colors in the terminal and also
moves the logger configuration to a different crate.

Signed-off-by: Sebastian Webber <sebastian@swebber.me>
2023-07-19 21:15:55 -07:00
Sebastian Webber
f85e5bd9e8 add support for multiple log formats (#517)
this commit adds the tracing-subscriber crate and use its formatters to
support multiple log formats.

More details in
https://github.com/postgresml/pgcat/issues/464#issuecomment-1641430299

Signed-off-by: Sebastian Webber <sebastian@swebber.me>
2023-07-18 23:07:13 -07:00
Sebastian Webber
7bdb4e5cd9 Add cmd line parser (#512)
This commit adds the clap library and configures the necessary args to
parse from the command line,  expanding the current option of a single
file and adding support for environment variables.

Signed-off-by: Sebastian Webber <sebastian@swebber.me>
2023-07-18 13:52:40 -07:00
Sebastian Webber
5d87e3781e push and build only in main and tags (#508)
this commit changes the CI behavior to only build and push when something is committed to main or is a new tag.
2023-07-14 10:30:49 -07:00
dependabot[bot]
3e08c6bd8d chore(deps): bump num_cpus from 1.15.0 to 1.16.0 (#507)
Bumps [num_cpus](https://github.com/seanmonstar/num_cpus) from 1.15.0 to 1.16.0.
- [Release notes](https://github.com/seanmonstar/num_cpus/releases)
- [Changelog](https://github.com/seanmonstar/num_cpus/blob/master/CHANGELOG.md)
- [Commits](https://github.com/seanmonstar/num_cpus/compare/v1.15.0...v1.16.0)

---
updated-dependencies:
- dependency-name: num_cpus
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-07-14 07:58:11 -07:00
Sebastian Webber
15b6db8e4e add "show help" command (#505)
This commit adds a new function to handle notify and use it
in the SHOW HELP command, which displays the available options
in the admin console.

Also, adding Fabrízio as a co-author for all the help with the
protocol and the help to structure this PR.

Signed-off-by: Sebastian Webber <sebastian@swebber.me>
Co-authored-by: Fabrízio de Royes Mello <fabriziomello@gmail.com>
2023-07-13 22:40:04 -07:00
dependabot[bot]
b2e6dfd9bb chore(deps): bump rustls-pemfile from 1.0.2 to 1.0.3 (#504)
Bumps [rustls-pemfile](https://github.com/rustls/pemfile) from 1.0.2 to 1.0.3.
- [Commits](https://github.com/rustls/pemfile/commits)

---
updated-dependencies:
- dependency-name: rustls-pemfile
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-07-12 21:41:48 -07:00
Mostafa Abdelraouf
3c9565d351 Add support for tcp_user_timeout (#503)
* Add support for tcp_user_timeout

* option

* duration

* Some()

* docs

* fmt, compile
2023-07-12 11:24:30 -07:00
dependabot[bot]
67579c9af4 chore(deps): bump rustls from 0.21.1 to 0.21.5 (#501)
Bumps [rustls](https://github.com/rustls/rustls) from 0.21.1 to 0.21.5.
- [Release notes](https://github.com/rustls/rustls/releases)
- [Commits](https://github.com/rustls/rustls/compare/v/0.21.1...v/0.21.5)

---
updated-dependencies:
- dependency-name: rustls
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-07-12 05:46:31 -07:00
Cluas
cf7f6f35ab docs: fix general.autoreload description (#491)
* docs: fix autoreload description

Signed-off-by: Cluas <Cluas@live.cn>

* docs: add blank line

Signed-off-by: Cluas <Cluas@live.cn>

---------

Signed-off-by: Cluas <Cluas@live.cn>
2023-07-12 05:42:44 -07:00
Voldemarich
7205537b49 [BUG] Fix binding of NULL value parameters in prepared statements (#496)
Fix binding of NULL value parameters in prepared statements

Co-authored-by: anon <anon@non.existent>
2023-07-10 10:35:43 +02:00
Zain Kabani
1ed6e925ed Fixes the default for round robing in General (#488) 2023-06-23 09:15:44 -07:00
Lev Kokotov
4b78af9676 Implement Close for prepared statements (#482)
* Partial support for Close

* Close

* respect config value

* prepared spec

* Hmm

* Print cache size
2023-06-18 23:02:34 -07:00
Lev Kokotov
73500c0c96 Fix build (#481) 2023-06-17 09:09:54 -07:00
Lev Kokotov
b167de5aa3 fmt (#480) 2023-06-17 08:57:33 -07:00
34 changed files with 1369 additions and 784 deletions

View File

@@ -1,6 +1,11 @@
name: Build and Push name: Build and Push
on: push on:
push:
branches:
- main
tags:
- v*
env: env:
registry: ghcr.io registry: ghcr.io

View File

@@ -116,10 +116,10 @@ If we should log client disconnections
### autoreload ### autoreload
``` ```
path: general.autoreload path: general.autoreload
default: 15000 default: 15000 # milliseconds
``` ```
When set to true, PgCat reloads configs if it detects a change in the config file. When set, PgCat automatically reloads its configurations at the specified interval (in milliseconds) if it detects changes in the configuration file. The default interval is 15000 milliseconds or 15 seconds.
### worker_threads ### worker_threads
``` ```
@@ -151,7 +151,13 @@ path: general.tcp_keepalives_interval
default: 5 default: 5
``` ```
Number of seconds between keepalive packets. ### tcp_user_timeout
```
path: general.tcp_user_timeout
default: 10000
```
A linux-only parameters that defines the amount of time in milliseconds that transmitted data may remain unacknowledged or buffered data may remain untransmitted (due to zero window size) before TCP will forcibly disconnect
### tls_certificate ### tls_certificate
``` ```

901
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "pgcat" name = "pgcat"
version = "1.0.2-alpha3" version = "1.1.0"
edition = "2021" edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
@@ -46,6 +46,9 @@ trust-dns-resolver = "0.22.0"
tokio-test = "0.4.2" tokio-test = "0.4.2"
serde_json = "1" serde_json = "1"
itertools = "0.10" itertools = "0.10"
clap = { version = "4.3.1", features = ["derive", "env"] }
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.17", features = ["json"]}
[target.'cfg(not(target_env = "msvc"))'.dependencies] [target.'cfg(not(target_env = "msvc"))'.dependencies]
jemallocator = "0.5.0" jemallocator = "0.5.0"

View File

@@ -63,6 +63,9 @@ tcp_keepalives_interval = 5
# Handle prepared statements. # Handle prepared statements.
prepared_statements = true prepared_statements = true
# Prepared statements server cache size.
prepared_statements_cache_size = 500
# Path to TLS Certificate file to use for TLS connections # Path to TLS Certificate file to use for TLS connections
# tls_certificate = ".circleci/server.cert" # tls_certificate = ".circleci/server.cert"
# Path to TLS private key file to use for TLS connections # Path to TLS private key file to use for TLS connections

View File

@@ -84,6 +84,10 @@ where
shutdown(stream).await shutdown(stream).await
} }
"SHOW" => match query_parts[1].to_ascii_uppercase().as_str() { "SHOW" => match query_parts[1].to_ascii_uppercase().as_str() {
"HELP" => {
trace!("SHOW HELP");
show_help(stream).await
}
"BANS" => { "BANS" => {
trace!("SHOW BANS"); trace!("SHOW BANS");
show_bans(stream).await show_bans(stream).await
@@ -271,6 +275,45 @@ where
write_all_half(stream, &res).await write_all_half(stream, &res).await
} }
/// Show all available options.
async fn show_help<T>(stream: &mut T) -> Result<(), Error>
where
T: tokio::io::AsyncWrite + std::marker::Unpin,
{
let mut res = BytesMut::new();
let detail_msg = vec![
"",
"SHOW HELP|CONFIG|DATABASES|POOLS|CLIENTS|SERVERS|USERS|VERSION",
// "SHOW PEERS|PEER_POOLS", // missing PEERS|PEER_POOLS
// "SHOW FDS|SOCKETS|ACTIVE_SOCKETS|LISTS|MEM|STATE", // missing FDS|SOCKETS|ACTIVE_SOCKETS|MEM|STATE
"SHOW LISTS",
// "SHOW DNS_HOSTS|DNS_ZONES", // missing DNS_HOSTS|DNS_ZONES
"SHOW STATS", // missing STATS_TOTALS|STATS_AVERAGES|TOTALS
"SET key = arg",
"RELOAD",
"PAUSE [<db>, <user>]",
"RESUME [<db>, <user>]",
// "DISABLE <db>", // missing
// "ENABLE <db>", // missing
// "RECONNECT [<db>]", missing
// "KILL <db>",
// "SUSPEND",
"SHUTDOWN",
// "WAIT_CLOSE [<db>]", // missing
];
res.put(notify("Console usage", detail_msg.join("\n\t")));
res.put(command_complete("SHOW"));
// ReadyForQuery
res.put_u8(b'Z');
res.put_i32(5);
res.put_u8(b'I');
write_all_half(stream, &res).await
}
/// Show shards and replicas. /// Show shards and replicas.
async fn show_databases<T>(stream: &mut T) -> Result<(), Error> async fn show_databases<T>(stream: &mut T) -> Result<(), Error>
where where
@@ -701,6 +744,7 @@ where
("age_seconds", DataType::Numeric), ("age_seconds", DataType::Numeric),
("prepare_cache_hit", DataType::Numeric), ("prepare_cache_hit", DataType::Numeric),
("prepare_cache_miss", DataType::Numeric), ("prepare_cache_miss", DataType::Numeric),
("prepare_cache_size", DataType::Numeric),
]; ];
let new_map = get_server_stats(); let new_map = get_server_stats();
@@ -732,6 +776,10 @@ where
.prepared_miss_count .prepared_miss_count
.load(Ordering::Relaxed) .load(Ordering::Relaxed)
.to_string(), .to_string(),
server
.prepared_cache_size
.load(Ordering::Relaxed)
.to_string(),
]; ];
res.put(data_row(&row)); res.put(data_row(&row));
@@ -752,7 +800,7 @@ async fn pause<T>(stream: &mut T, query: &str) -> Result<(), Error>
where where
T: tokio::io::AsyncWrite + std::marker::Unpin, T: tokio::io::AsyncWrite + std::marker::Unpin,
{ {
let parts: Vec<&str> = query.split(",").map(|part| part.trim()).collect(); let parts: Vec<&str> = query.split(',').map(|part| part.trim()).collect();
if parts.len() != 2 { if parts.len() != 2 {
error_response( error_response(
@@ -799,7 +847,7 @@ async fn resume<T>(stream: &mut T, query: &str) -> Result<(), Error>
where where
T: tokio::io::AsyncWrite + std::marker::Unpin, T: tokio::io::AsyncWrite + std::marker::Unpin,
{ {
let parts: Vec<&str> = query.split(",").map(|part| part.trim()).collect(); let parts: Vec<&str> = query.split(',').map(|part| part.trim()).collect();
if parts.len() != 2 { if parts.len() != 2 {
error_response( error_response(

View File

@@ -12,7 +12,7 @@ pub struct AuthPassthrough {
impl AuthPassthrough { impl AuthPassthrough {
/// Initializes an AuthPassthrough. /// Initializes an AuthPassthrough.
pub fn new(query: &str, user: &str, password: &str) -> Self { pub fn new<S: ToString>(query: S, user: S, password: S) -> Self {
AuthPassthrough { AuthPassthrough {
password: password.to_string(), password: password.to_string(),
query: query.to_string(), query: query.to_string(),

View File

@@ -123,7 +123,7 @@ pub async fn client_entrypoint(
// Client requested a TLS connection. // Client requested a TLS connection.
Ok((ClientConnectionType::Tls, _)) => { Ok((ClientConnectionType::Tls, _)) => {
// TLS settings are configured, will setup TLS now. // TLS settings are configured, will setup TLS now.
if tls_certificate != None { if tls_certificate.is_some() {
debug!("Accepting TLS request"); debug!("Accepting TLS request");
let mut yes = BytesMut::new(); let mut yes = BytesMut::new();
@@ -431,7 +431,7 @@ where
None => "pgcat", None => "pgcat",
}; };
let client_identifier = ClientIdentifier::new(&application_name, &username, &pool_name); let client_identifier = ClientIdentifier::new(application_name, username, pool_name);
let admin = ["pgcat", "pgbouncer"] let admin = ["pgcat", "pgbouncer"]
.iter() .iter()
@@ -906,6 +906,19 @@ where
return Ok(()); return Ok(());
} }
// Close (F)
'C' => {
if prepared_statements_enabled {
let close: Close = (&message).try_into()?;
if close.is_prepared_statement() && !close.anonymous() {
self.prepared_statements.remove(&close.name);
write_all_flush(&mut self.write, &close_complete()).await?;
continue;
}
}
}
_ => (), _ => (),
} }
@@ -917,16 +930,12 @@ where
} }
// Check on plugin results. // Check on plugin results.
match plugin_output { if let Some(PluginOutput::Deny(error)) = plugin_output {
Some(PluginOutput::Deny(error)) => { self.buffer.clear();
self.buffer.clear(); error_response(&mut self.write, &error).await?;
error_response(&mut self.write, &error).await?; plugin_output = None;
plugin_output = None; continue;
continue; }
}
_ => (),
};
// Get a pool instance referenced by the most up-to-date // Get a pool instance referenced by the most up-to-date
// pointer. This ensures we always read the latest config // pointer. This ensures we always read the latest config
@@ -1130,7 +1139,17 @@ where
} else { } else {
// The statement is not prepared on the server, so we need to prepare it. // The statement is not prepared on the server, so we need to prepare it.
if server.should_prepare(&statement.name) { if server.should_prepare(&statement.name) {
server.prepare(statement).await?; match server.prepare(statement).await {
Ok(_) => (),
Err(err) => {
pool.ban(
&address,
BanReason::MessageSendFailed,
Some(&self.stats),
);
return Err(err);
}
}
} }
} }
@@ -1190,7 +1209,7 @@ where
// Safe to unwrap because we know this message has a certain length and has the code // Safe to unwrap because we know this message has a certain length and has the code
// This reads the first byte without advancing the internal pointer and mutating the bytes // This reads the first byte without advancing the internal pointer and mutating the bytes
let code = *message.get(0).unwrap() as char; let code = *message.first().unwrap() as char;
trace!("Message: {}", code); trace!("Message: {}", code);
@@ -1237,7 +1256,7 @@ where
// Release server back to the pool if we are in transaction mode. // Release server back to the pool if we are in transaction mode.
// If we are in session mode, we keep the server until the client disconnects. // If we are in session mode, we keep the server until the client disconnects.
if self.transaction_mode { if self.transaction_mode && !server.in_copy_mode() {
self.stats.idle(); self.stats.idle();
break; break;
@@ -1251,6 +1270,10 @@ where
self.stats.disconnect(); self.stats.disconnect();
self.release(); self.release();
if prepared_statements_enabled {
server.maintain_cache().await?;
}
return Ok(()); return Ok(());
} }
@@ -1300,6 +1323,18 @@ where
// Close the prepared statement. // Close the prepared statement.
'C' => { 'C' => {
if prepared_statements_enabled {
let close: Close = (&message).try_into()?;
if close.is_prepared_statement() && !close.anonymous() {
if let Some(parse) = self.prepared_statements.get(&close.name) {
server.will_close(&parse.generated_name);
} else {
// A prepared statement slipped through? Not impossible, since we don't support PREPARE yet.
}
}
}
self.buffer.put(&message[..]); self.buffer.put(&message[..]);
} }
@@ -1334,7 +1369,7 @@ where
self.buffer.put(&message[..]); self.buffer.put(&message[..]);
let first_message_code = (*self.buffer.get(0).unwrap_or(&0)) as char; let first_message_code = (*self.buffer.first().unwrap_or(&0)) as char;
// Almost certainly true // Almost certainly true
if first_message_code == 'P' && !prepared_statements_enabled { if first_message_code == 'P' && !prepared_statements_enabled {
@@ -1368,7 +1403,7 @@ where
// Release server back to the pool if we are in transaction mode. // Release server back to the pool if we are in transaction mode.
// If we are in session mode, we keep the server until the client disconnects. // If we are in session mode, we keep the server until the client disconnects.
if self.transaction_mode { if self.transaction_mode && !server.in_copy_mode() {
break; break;
} }
} }
@@ -1433,7 +1468,13 @@ where
// The server is no longer bound to us, we can't cancel it's queries anymore. // The server is no longer bound to us, we can't cancel it's queries anymore.
debug!("Releasing server back into the pool"); debug!("Releasing server back into the pool");
server.checkin_cleanup().await?; server.checkin_cleanup().await?;
if prepared_statements_enabled {
server.maintain_cache().await?;
}
server.stats().idle(); server.stats().idle();
self.connected_to_server = false; self.connected_to_server = false;

36
src/cmd_args.rs Normal file
View File

@@ -0,0 +1,36 @@
use clap::{Parser, ValueEnum};
use tracing::Level;
/// PgCat: Nextgen PostgreSQL Pooler
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
pub struct Args {
#[arg(default_value_t = String::from("pgcat.toml"), env)]
pub config_file: String,
#[arg(short, long, default_value_t = tracing::Level::INFO, env)]
pub log_level: Level,
#[clap(short='F', long, value_enum, default_value_t=LogFormat::Text, env)]
pub log_format: LogFormat,
#[arg(
short,
long,
default_value_t = false,
env,
help = "disable colors in the log output"
)]
pub no_color: bool,
}
pub fn parse() -> Args {
Args::parse()
}
#[derive(ValueEnum, Clone, Debug)]
pub enum LogFormat {
Text,
Structured,
Debug,
}

View File

@@ -217,19 +217,15 @@ impl Default for User {
impl User { impl User {
fn validate(&self) -> Result<(), Error> { fn validate(&self) -> Result<(), Error> {
match self.min_pool_size { if let Some(min_pool_size) = self.min_pool_size {
Some(min_pool_size) => { if min_pool_size > self.pool_size {
if min_pool_size > self.pool_size { error!(
error!( "min_pool_size of {} cannot be larger than pool_size of {}",
"min_pool_size of {} cannot be larger than pool_size of {}", min_pool_size, self.pool_size
min_pool_size, self.pool_size );
); return Err(Error::BadConfig);
return Err(Error::BadConfig);
}
} }
}
None => (),
};
Ok(()) Ok(())
} }
@@ -261,6 +257,8 @@ pub struct General {
pub tcp_keepalives_count: u32, pub tcp_keepalives_count: u32,
#[serde(default = "General::default_tcp_keepalives_interval")] #[serde(default = "General::default_tcp_keepalives_interval")]
pub tcp_keepalives_interval: u64, pub tcp_keepalives_interval: u64,
#[serde(default = "General::default_tcp_user_timeout")]
pub tcp_user_timeout: u64,
#[serde(default)] // False #[serde(default)] // False
pub log_client_connections: bool, pub log_client_connections: bool,
@@ -323,6 +321,9 @@ pub struct General {
#[serde(default)] #[serde(default)]
pub prepared_statements: bool, pub prepared_statements: bool,
#[serde(default = "General::default_prepared_statements_cache_size")]
pub prepared_statements_cache_size: usize,
} }
impl General { impl General {
@@ -357,6 +358,10 @@ impl General {
5 // 5 seconds 5 // 5 seconds
} }
pub fn default_tcp_user_timeout() -> u64 {
10000 // 10000 milliseconds
}
pub fn default_idle_timeout() -> u64 { pub fn default_idle_timeout() -> u64 {
600000 // 10 minutes 600000 // 10 minutes
} }
@@ -400,6 +405,10 @@ impl General {
pub fn default_server_round_robin() -> bool { pub fn default_server_round_robin() -> bool {
true true
} }
pub fn default_prepared_statements_cache_size() -> usize {
500
}
} }
impl Default for General { impl Default for General {
@@ -420,6 +429,7 @@ impl Default for General {
tcp_keepalives_idle: Self::default_tcp_keepalives_idle(), tcp_keepalives_idle: Self::default_tcp_keepalives_idle(),
tcp_keepalives_count: Self::default_tcp_keepalives_count(), tcp_keepalives_count: Self::default_tcp_keepalives_count(),
tcp_keepalives_interval: Self::default_tcp_keepalives_interval(), tcp_keepalives_interval: Self::default_tcp_keepalives_interval(),
tcp_user_timeout: Self::default_tcp_user_timeout(),
log_client_connections: false, log_client_connections: false,
log_client_disconnections: false, log_client_disconnections: false,
autoreload: None, autoreload: None,
@@ -435,9 +445,10 @@ impl Default for General {
auth_query_user: None, auth_query_user: None,
auth_query_password: None, auth_query_password: None,
server_lifetime: Self::default_server_lifetime(), server_lifetime: Self::default_server_lifetime(),
server_round_robin: false, server_round_robin: Self::default_server_round_robin(),
validate_config: true, validate_config: true,
prepared_statements: false, prepared_statements: false,
prepared_statements_cache_size: 500,
} }
} }
} }
@@ -616,9 +627,9 @@ impl Pool {
Some(key) => { Some(key) => {
// No quotes in the key so we don't have to compare quoted // No quotes in the key so we don't have to compare quoted
// to unquoted idents. // to unquoted idents.
let key = key.replace("\"", ""); let key = key.replace('\"', "");
if key.split(".").count() != 2 { if key.split('.').count() != 2 {
error!( error!(
"automatic_sharding_key '{}' must be fully qualified, e.g. t.{}`", "automatic_sharding_key '{}' must be fully qualified, e.g. t.{}`",
key, key key, key
@@ -631,7 +642,7 @@ impl Pool {
None => None, None => None,
}; };
for (_, user) in &self.users { for user in self.users.values() {
user.validate()?; user.validate()?;
} }
@@ -803,8 +814,8 @@ pub struct Query {
impl Query { impl Query {
pub fn substitute(&mut self, db: &str, user: &str) { pub fn substitute(&mut self, db: &str, user: &str) {
for col in self.result.iter_mut() { for col in self.result.iter_mut() {
for i in 0..col.len() { for c in col {
col[i] = col[i].replace("${USER}", user).replace("${DATABASE}", db); *c = c.replace("${USER}", user).replace("${DATABASE}", db);
} }
} }
} }
@@ -914,8 +925,8 @@ impl From<&Config> for std::collections::HashMap<String, String> {
( (
format!("pools.{:?}.users", pool_name), format!("pools.{:?}.users", pool_name),
pool.users pool.users
.iter() .values()
.map(|(_username, user)| &user.username) .map(|user| &user.username)
.cloned() .cloned()
.collect::<Vec<String>>() .collect::<Vec<String>>()
.join(", "), .join(", "),
@@ -1000,13 +1011,9 @@ impl Config {
Some(tls_certificate) => { Some(tls_certificate) => {
info!("TLS certificate: {}", tls_certificate); info!("TLS certificate: {}", tls_certificate);
match self.general.tls_private_key.clone() { if let Some(tls_private_key) = self.general.tls_private_key.clone() {
Some(tls_private_key) => { info!("TLS private key: {}", tls_private_key);
info!("TLS private key: {}", tls_private_key); info!("TLS support is enabled");
info!("TLS support is enabled");
}
None => (),
} }
} }
@@ -1020,6 +1027,12 @@ impl Config {
self.general.verify_server_certificate self.general.verify_server_certificate
); );
info!("Prepared statements: {}", self.general.prepared_statements); info!("Prepared statements: {}", self.general.prepared_statements);
if self.general.prepared_statements {
info!(
"Prepared statements server cache size: {}",
self.general.prepared_statements_cache_size
);
}
info!( info!(
"Plugins: {}", "Plugins: {}",
match self.plugins { match self.plugins {
@@ -1035,8 +1048,8 @@ impl Config {
pool_name, pool_name,
pool_config pool_config
.users .users
.iter() .values()
.map(|(_, user_cfg)| user_cfg.pool_size) .map(|user_cfg| user_cfg.pool_size)
.sum::<u32>() .sum::<u32>()
.to_string() .to_string()
); );
@@ -1193,35 +1206,32 @@ impl Config {
} }
// Validate TLS! // Validate TLS!
match self.general.tls_certificate.clone() { if let Some(tls_certificate) = self.general.tls_certificate.clone() {
Some(tls_certificate) => { match load_certs(Path::new(&tls_certificate)) {
match load_certs(Path::new(&tls_certificate)) { Ok(_) => {
Ok(_) => { // Cert is okay, but what about the private key?
// Cert is okay, but what about the private key? match self.general.tls_private_key.clone() {
match self.general.tls_private_key.clone() { Some(tls_private_key) => match load_keys(Path::new(&tls_private_key)) {
Some(tls_private_key) => match load_keys(Path::new(&tls_private_key)) { Ok(_) => (),
Ok(_) => (), Err(err) => {
Err(err) => { error!("tls_private_key is incorrectly configured: {:?}", err);
error!("tls_private_key is incorrectly configured: {:?}", err);
return Err(Error::BadConfig);
}
},
None => {
error!("tls_certificate is set, but the tls_private_key is not");
return Err(Error::BadConfig); return Err(Error::BadConfig);
} }
}; },
}
Err(err) => { None => {
error!("tls_certificate is incorrectly configured: {:?}", err); error!("tls_certificate is set, but the tls_private_key is not");
return Err(Error::BadConfig); return Err(Error::BadConfig);
} }
};
}
Err(err) => {
error!("tls_certificate is incorrectly configured: {:?}", err);
return Err(Error::BadConfig);
} }
} }
None => (), }
};
for pool in self.pools.values_mut() { for pool in self.pools.values_mut() {
pool.validate()?; pool.validate()?;
@@ -1239,13 +1249,15 @@ pub fn get_config() -> Config {
} }
pub fn get_idle_client_in_transaction_timeout() -> u64 { pub fn get_idle_client_in_transaction_timeout() -> u64 {
(*(*CONFIG.load())) CONFIG.load().general.idle_client_in_transaction_timeout
.general
.idle_client_in_transaction_timeout
} }
pub fn get_prepared_statements() -> bool { pub fn get_prepared_statements() -> bool {
(*(*CONFIG.load())).general.prepared_statements CONFIG.load().general.prepared_statements
}
pub fn get_prepared_statements_cache_size() -> usize {
CONFIG.load().general.prepared_statements_cache_size
} }
/// Parse the configuration file located at the path. /// Parse the configuration file located at the path.

View File

@@ -37,11 +37,11 @@ pub struct ClientIdentifier {
} }
impl ClientIdentifier { impl ClientIdentifier {
pub fn new(application_name: &str, username: &str, pool_name: &str) -> ClientIdentifier { pub fn new<S: ToString>(application_name: S, username: S, pool_name: S) -> ClientIdentifier {
ClientIdentifier { ClientIdentifier {
application_name: application_name.into(), application_name: application_name.to_string(),
username: username.into(), username: username.to_string(),
pool_name: pool_name.into(), pool_name: pool_name.to_string(),
} }
} }
} }
@@ -63,10 +63,10 @@ pub struct ServerIdentifier {
} }
impl ServerIdentifier { impl ServerIdentifier {
pub fn new(username: &str, database: &str) -> ServerIdentifier { pub fn new<S: ToString>(username: S, database: S) -> ServerIdentifier {
ServerIdentifier { ServerIdentifier {
username: username.into(), username: username.to_string(),
database: database.into(), database: database.to_string(),
} }
} }
} }
@@ -84,41 +84,36 @@ impl std::fmt::Display for ServerIdentifier {
impl std::fmt::Display for Error { impl std::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match &self { match &self {
&Error::ClientSocketError(error, client_identifier) => write!( Error::ClientSocketError(error, client_identifier) => {
f, write!(f, "Error reading {error} from client {client_identifier}",)
"Error reading {} from client {}",
error, client_identifier
),
&Error::ClientGeneralError(error, client_identifier) => {
write!(f, "{} {}", error, client_identifier)
} }
&Error::ClientAuthImpossible(username) => write!( Error::ClientGeneralError(error, client_identifier) => {
write!(f, "{error} {client_identifier}")
}
Error::ClientAuthImpossible(username) => write!(
f, f,
"Client auth not possible, \ "Client auth not possible, \
no cleartext password set for username: {} \ no cleartext password set for username: {username} \
in config and auth passthrough (query_auth) \ in config and auth passthrough (query_auth) \
is not set up.", is not set up."
username
), ),
&Error::ClientAuthPassthroughError(error, client_identifier) => write!( Error::ClientAuthPassthroughError(error, client_identifier) => write!(
f, f,
"No cleartext password set, \ "No cleartext password set, \
and no auth passthrough could not \ and no auth passthrough could not \
obtain the hash from server for {}, \ obtain the hash from server for {client_identifier}, \
the error was: {}", the error was: {error}",
client_identifier, error
), ),
&Error::ServerStartupError(error, server_identifier) => write!( Error::ServerStartupError(error, server_identifier) => write!(
f, f,
"Error reading {} on server startup {}", "Error reading {error} on server startup {server_identifier}",
error, server_identifier,
), ),
&Error::ServerAuthError(error, server_identifier) => { Error::ServerAuthError(error, server_identifier) => {
write!(f, "{} for {}", error, server_identifier,) write!(f, "{error} for {server_identifier}")
} }
// The rest can use Debug. // The rest can use Debug.
err => write!(f, "{:?}", err), err => write!(f, "{err:?}"),
} }
} }
} }

View File

@@ -1,13 +1,14 @@
pub mod admin; pub mod admin;
pub mod auth_passthrough; pub mod auth_passthrough;
pub mod client; pub mod client;
pub mod cmd_args;
pub mod config; pub mod config;
pub mod constants; pub mod constants;
pub mod dns_cache; pub mod dns_cache;
pub mod errors; pub mod errors;
pub mod logger;
pub mod messages; pub mod messages;
pub mod mirrors; pub mod mirrors;
pub mod multi_logger;
pub mod plugins; pub mod plugins;
pub mod pool; pub mod pool;
pub mod prometheus; pub mod prometheus;
@@ -24,18 +25,11 @@ pub mod tls;
/// ///
/// * `duration` - A duration of time /// * `duration` - A duration of time
pub fn format_duration(duration: &chrono::Duration) -> String { pub fn format_duration(duration: &chrono::Duration) -> String {
let milliseconds = format!("{:0>3}", duration.num_milliseconds() % 1000); let milliseconds = duration.num_milliseconds() % 1000;
let seconds = duration.num_seconds() % 60;
let minutes = duration.num_minutes() % 60;
let hours = duration.num_hours() % 24;
let days = duration.num_days();
let seconds = format!("{:0>2}", duration.num_seconds() % 60); format!("{days}d {hours:0>2}:{minutes:0>2}:{seconds:0>2}.{milliseconds:0>3}")
let minutes = format!("{:0>2}", duration.num_minutes() % 60);
let hours = format!("{:0>2}", duration.num_hours() % 24);
let days = duration.num_days().to_string();
format!(
"{}d {}:{}:{}.{}",
days, hours, minutes, seconds, milliseconds
)
} }

14
src/logger.rs Normal file
View File

@@ -0,0 +1,14 @@
use crate::cmd_args::{Args, LogFormat};
use tracing_subscriber;
pub fn init(args: &Args) {
let trace_sub = tracing_subscriber::fmt()
.with_max_level(args.log_level)
.with_ansi(!args.no_color);
match args.log_format {
LogFormat::Structured => trace_sub.json().init(),
LogFormat::Debug => trace_sub.pretty().init(),
_ => trace_sub.init(),
};
}

View File

@@ -61,15 +61,18 @@ use std::str::FromStr;
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::broadcast; use tokio::sync::broadcast;
use pgcat::cmd_args;
use pgcat::config::{get_config, reload_config, VERSION}; use pgcat::config::{get_config, reload_config, VERSION};
use pgcat::dns_cache; use pgcat::dns_cache;
use pgcat::logger;
use pgcat::messages::configure_socket; use pgcat::messages::configure_socket;
use pgcat::pool::{ClientServerMap, ConnectionPool}; use pgcat::pool::{ClientServerMap, ConnectionPool};
use pgcat::prometheus::start_metric_server; use pgcat::prometheus::start_metric_server;
use pgcat::stats::{Collector, Reporter, REPORTER}; use pgcat::stats::{Collector, Reporter, REPORTER};
fn main() -> Result<(), Box<dyn std::error::Error>> { fn main() -> Result<(), Box<dyn std::error::Error>> {
pgcat::multi_logger::MultiLogger::init().unwrap(); let args = cmd_args::parse();
logger::init(&args);
info!("Welcome to PgCat! Meow. (Version {})", VERSION); info!("Welcome to PgCat! Meow. (Version {})", VERSION);
@@ -78,20 +81,12 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
std::process::exit(exitcode::CONFIG); std::process::exit(exitcode::CONFIG);
} }
let args = std::env::args().collect::<Vec<String>>();
let config_file = if args.len() == 2 {
args[1].to_string()
} else {
String::from("pgcat.toml")
};
// Create a transient runtime for loading the config for the first time. // Create a transient runtime for loading the config for the first time.
{ {
let runtime = Builder::new_multi_thread().worker_threads(1).build()?; let runtime = Builder::new_multi_thread().worker_threads(1).build()?;
runtime.block_on(async { runtime.block_on(async {
match pgcat::config::parse(&config_file).await { match pgcat::config::parse(args.config_file.as_str()).await {
Ok(_) => (), Ok(_) => (),
Err(err) => { Err(err) => {
error!("Config parse error: {:?}", err); error!("Config parse error: {:?}", err);
@@ -165,10 +160,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
} }
}; };
tokio::task::spawn(async move { Collector::collect();
let mut stats_collector = Collector::default();
stats_collector.collect().await;
});
info!("Config autoreloader: {}", match config.general.autoreload { info!("Config autoreloader: {}", match config.general.autoreload {
Some(interval) => format!("{} ms", interval), Some(interval) => format!("{} ms", interval),

View File

@@ -1,7 +1,7 @@
/// Helper functions to send one-off protocol messages /// Helper functions to send one-off protocol messages
/// and handle TcpStream (TCP socket). /// and handle TcpStream (TCP socket).
use bytes::{Buf, BufMut, BytesMut}; use bytes::{Buf, BufMut, BytesMut};
use log::error; use log::{debug, error};
use md5::{Digest, Md5}; use md5::{Digest, Md5};
use socket2::{SockRef, TcpKeepalive}; use socket2::{SockRef, TcpKeepalive};
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWriteExt};
@@ -156,12 +156,10 @@ where
match stream.write_all(&startup).await { match stream.write_all(&startup).await {
Ok(_) => Ok(()), Ok(_) => Ok(()),
Err(err) => { Err(err) => Err(Error::SocketError(format!(
return Err(Error::SocketError(format!( "Error writing startup to server socket - Error: {:?}",
"Error writing startup to server socket - Error: {:?}", err
err ))),
)))
}
} }
} }
@@ -237,8 +235,8 @@ pub fn md5_hash_password(user: &str, password: &str, salt: &[u8]) -> Vec<u8> {
let mut md5 = Md5::new(); let mut md5 = Md5::new();
// First pass // First pass
md5.update(&password.as_bytes()); md5.update(password.as_bytes());
md5.update(&user.as_bytes()); md5.update(user.as_bytes());
let output = md5.finalize_reset(); let output = md5.finalize_reset();
@@ -274,7 +272,7 @@ where
{ {
let password = md5_hash_password(user, password, salt); let password = md5_hash_password(user, password, salt);
let mut message = BytesMut::with_capacity(password.len() as usize + 5); let mut message = BytesMut::with_capacity(password.len() + 5);
message.put_u8(b'p'); message.put_u8(b'p');
message.put_i32(password.len() as i32 + 4); message.put_i32(password.len() as i32 + 4);
@@ -288,7 +286,7 @@ where
S: tokio::io::AsyncWrite + std::marker::Unpin, S: tokio::io::AsyncWrite + std::marker::Unpin,
{ {
let password = md5_hash_second_pass(hash, salt); let password = md5_hash_second_pass(hash, salt);
let mut message = BytesMut::with_capacity(password.len() as usize + 5); let mut message = BytesMut::with_capacity(password.len() + 5);
message.put_u8(b'p'); message.put_u8(b'p');
message.put_i32(password.len() as i32 + 4); message.put_i32(password.len() as i32 + 4);
@@ -509,7 +507,7 @@ pub fn data_row_nullable(row: &Vec<Option<String>>) -> BytesMut {
data_row.put_i32(column.len() as i32); data_row.put_i32(column.len() as i32);
data_row.put_slice(column); data_row.put_slice(column);
} else { } else {
data_row.put_i32(-1 as i32); data_row.put_i32(-1_i32);
} }
} }
@@ -530,6 +528,26 @@ pub fn command_complete(command: &str) -> BytesMut {
res res
} }
/// Create a notify message.
pub fn notify(message: &str, details: String) -> BytesMut {
let mut notify_cmd = BytesMut::new();
notify_cmd.put_slice("SNOTICE\0".as_bytes());
notify_cmd.put_slice("C00000\0".as_bytes());
notify_cmd.put_slice(format!("M{}\0", message).as_bytes());
notify_cmd.put_slice(format!("D{}\0", details).as_bytes());
// this extra byte says that is the end of the package
notify_cmd.put_u8(0);
let mut res = BytesMut::new();
res.put_u8(b'N');
res.put_i32(notify_cmd.len() as i32 + 4);
res.put(notify_cmd);
res
}
pub fn flush() -> BytesMut { pub fn flush() -> BytesMut {
let mut bytes = BytesMut::new(); let mut bytes = BytesMut::new();
bytes.put_u8(b'H'); bytes.put_u8(b'H');
@@ -544,12 +562,10 @@ where
{ {
match stream.write_all(&buf).await { match stream.write_all(&buf).await {
Ok(_) => Ok(()), Ok(_) => Ok(()),
Err(err) => { Err(err) => Err(Error::SocketError(format!(
return Err(Error::SocketError(format!( "Error writing to socket - Error: {:?}",
"Error writing to socket - Error: {:?}", err
err ))),
)))
}
} }
} }
@@ -560,12 +576,10 @@ where
{ {
match stream.write_all(buf).await { match stream.write_all(buf).await {
Ok(_) => Ok(()), Ok(_) => Ok(()),
Err(err) => { Err(err) => Err(Error::SocketError(format!(
return Err(Error::SocketError(format!( "Error writing to socket - Error: {:?}",
"Error writing to socket - Error: {:?}", err
err ))),
)))
}
} }
} }
@@ -576,19 +590,15 @@ where
match stream.write_all(buf).await { match stream.write_all(buf).await {
Ok(_) => match stream.flush().await { Ok(_) => match stream.flush().await {
Ok(_) => Ok(()), Ok(_) => Ok(()),
Err(err) => { Err(err) => Err(Error::SocketError(format!(
return Err(Error::SocketError(format!( "Error flushing socket - Error: {:?}",
"Error flushing socket - Error: {:?}",
err
)))
}
},
Err(err) => {
return Err(Error::SocketError(format!(
"Error writing to socket - Error: {:?}",
err err
))) ))),
} },
Err(err) => Err(Error::SocketError(format!(
"Error writing to socket - Error: {:?}",
err
))),
} }
} }
@@ -669,6 +679,13 @@ pub fn configure_socket(stream: &TcpStream) {
let sock_ref = SockRef::from(stream); let sock_ref = SockRef::from(stream);
let conf = get_config(); let conf = get_config();
#[cfg(target_os = "linux")]
match sock_ref.set_tcp_user_timeout(Some(Duration::from_millis(conf.general.tcp_user_timeout)))
{
Ok(_) => (),
Err(err) => error!("Could not configure tcp_user_timeout for socket: {}", err),
}
match sock_ref.set_keepalive(true) { match sock_ref.set_keepalive(true) {
Ok(_) => { Ok(_) => {
match sock_ref.set_tcp_keepalive( match sock_ref.set_tcp_keepalive(
@@ -678,7 +695,7 @@ pub fn configure_socket(stream: &TcpStream) {
.with_time(Duration::from_secs(conf.general.tcp_keepalives_idle)), .with_time(Duration::from_secs(conf.general.tcp_keepalives_idle)),
) { ) {
Ok(_) => (), Ok(_) => (),
Err(err) => error!("Could not configure socket: {}", err), Err(err) => error!("Could not configure tcp_keepalive for socket: {}", err),
} }
} }
Err(err) => error!("Could not configure socket: {}", err), Err(err) => error!("Could not configure socket: {}", err),
@@ -696,7 +713,7 @@ impl BytesMutReader for Cursor<&BytesMut> {
let mut buf = vec![]; let mut buf = vec![];
match self.read_until(b'\0', &mut buf) { match self.read_until(b'\0', &mut buf) {
Ok(_) => Ok(String::from_utf8_lossy(&buf[..buf.len() - 1]).to_string()), Ok(_) => Ok(String::from_utf8_lossy(&buf[..buf.len() - 1]).to_string()),
Err(err) => return Err(Error::ParseBytesError(err.to_string())), Err(err) => Err(Error::ParseBytesError(err.to_string())),
} }
} }
} }
@@ -832,10 +849,21 @@ impl TryFrom<&BytesMut> for Bind {
for _ in 0..num_param_values { for _ in 0..num_param_values {
let param_len = cursor.get_i32(); let param_len = cursor.get_i32();
let mut param = BytesMut::with_capacity(param_len as usize); // There is special occasion when the parameter is NULL
param.resize(param_len as usize, b'0'); // In that case, param length is defined as -1
cursor.copy_to_slice(&mut param); // So if the passed parameter len is over 0
param_values.push((param_len, param)); if param_len > 0 {
let mut param = BytesMut::with_capacity(param_len as usize);
param.resize(param_len as usize, b'0');
cursor.copy_to_slice(&mut param);
// we push and the length and the parameter into vector
param_values.push((param_len, param));
} else {
// otherwise we push a tuple with -1 and 0-len BytesMut
// which means that after encountering -1 postgres proceeds
// to processing another parameter
param_values.push((param_len, BytesMut::new()));
}
} }
let num_result_column_format_codes = cursor.get_i16(); let num_result_column_format_codes = cursor.get_i16();
@@ -976,6 +1004,84 @@ impl Describe {
} }
} }
/// Close (F) message.
/// See: <https://www.postgresql.org/docs/current/protocol-message-formats.html>
#[derive(Clone, Debug)]
pub struct Close {
code: char,
#[allow(dead_code)]
len: i32,
close_type: char,
pub name: String,
}
impl TryFrom<&BytesMut> for Close {
type Error = Error;
fn try_from(bytes: &BytesMut) -> Result<Close, Error> {
let mut cursor = Cursor::new(bytes);
let code = cursor.get_u8() as char;
let len = cursor.get_i32();
let close_type = cursor.get_u8() as char;
let name = cursor.read_string()?;
Ok(Close {
code,
len,
close_type,
name,
})
}
}
impl TryFrom<Close> for BytesMut {
type Error = Error;
fn try_from(close: Close) -> Result<BytesMut, Error> {
debug!("Close: {:?}", close);
let mut bytes = BytesMut::new();
let name_binding = CString::new(close.name)?;
let name = name_binding.as_bytes_with_nul();
let len = 4 + 1 + name.len();
bytes.put_u8(close.code as u8);
bytes.put_i32(len as i32);
bytes.put_u8(close.close_type as u8);
bytes.put_slice(name);
Ok(bytes)
}
}
impl Close {
pub fn new(name: &str) -> Close {
let name = name.to_string();
Close {
code: 'C',
len: 4 + 1 + name.len() as i32 + 1, // will be recalculated
close_type: 'S',
name,
}
}
pub fn is_prepared_statement(&self) -> bool {
self.close_type == 'S'
}
pub fn anonymous(&self) -> bool {
self.name.is_empty()
}
}
pub fn close_complete() -> BytesMut {
let mut bytes = BytesMut::new();
bytes.put_u8(b'3');
bytes.put_i32(4);
bytes
}
pub fn prepared_statement_name() -> String { pub fn prepared_statement_name() -> String {
format!( format!(
"P_{}", "P_{}",

View File

@@ -142,12 +142,12 @@ impl MirroringManager {
}); });
Self { Self {
byte_senders: byte_senders, byte_senders,
disconnect_senders: exit_senders, disconnect_senders: exit_senders,
} }
} }
pub fn send(self: &mut Self, bytes: &BytesMut) { pub fn send(&mut self, bytes: &BytesMut) {
// We want to avoid performing an allocation if we won't be able to send the message // We want to avoid performing an allocation if we won't be able to send the message
// There is a possibility of a race here where we check the capacity and then the channel is // There is a possibility of a race here where we check the capacity and then the channel is
// closed or the capacity is reduced to 0, but mirroring is best effort anyway // closed or the capacity is reduced to 0, but mirroring is best effort anyway
@@ -169,7 +169,7 @@ impl MirroringManager {
}); });
} }
pub fn disconnect(self: &mut Self) { pub fn disconnect(&mut self) {
self.disconnect_senders self.disconnect_senders
.iter_mut() .iter_mut()
.for_each(|sender| match sender.try_send(()) { .for_each(|sender| match sender.try_send(()) {

View File

@@ -1,80 +0,0 @@
use log::{Level, Log, Metadata, Record, SetLoggerError};
// This is a special kind of logger that allows sending logs to different
// targets depending on the log level.
//
// By default, if nothing is set, it acts as a regular env_log logger,
// it sends everything to standard error.
//
// If the Env variable `STDOUT_LOG` is defined, it will be used for
// configuring the standard out logger.
//
// The behavior is:
// - If it is an error, the message is written to standard error.
// - If it is not, and it matches the log level of the standard output logger (`STDOUT_LOG` env var), it will be send to standard output.
// - If the above is not true, it is sent to the stderr logger that will log it or not depending on the value
// of the RUST_LOG env var.
//
// So to summarize, if no `STDOUT_LOG` env var is present, the logger is the default logger. If `STDOUT_LOG` is set, everything
// but errors, that matches the log level set in the `STDOUT_LOG` env var is sent to stdout. You can have also some esoteric configuration
// where you set `RUST_LOG=debug` and `STDOUT_LOG=info`, in here, errors will go to stderr, warns and infos to stdout and debugs to stderr.
//
pub struct MultiLogger {
stderr_logger: env_logger::Logger,
stdout_logger: env_logger::Logger,
}
impl MultiLogger {
fn new() -> Self {
let stderr_logger = env_logger::builder().format_timestamp_micros().build();
let stdout_logger = env_logger::Builder::from_env("STDOUT_LOG")
.format_timestamp_micros()
.target(env_logger::Target::Stdout)
.build();
Self {
stderr_logger,
stdout_logger,
}
}
pub fn init() -> Result<(), SetLoggerError> {
let logger = Self::new();
log::set_max_level(logger.stderr_logger.filter());
log::set_boxed_logger(Box::new(logger))
}
}
impl Log for MultiLogger {
fn enabled(&self, metadata: &Metadata) -> bool {
self.stderr_logger.enabled(metadata) && self.stdout_logger.enabled(metadata)
}
fn log(&self, record: &Record) {
if record.level() == Level::Error {
self.stderr_logger.log(record);
} else {
if self.stdout_logger.matches(record) {
self.stdout_logger.log(record);
} else {
self.stderr_logger.log(record);
}
}
}
fn flush(&self) {
self.stderr_logger.flush();
self.stdout_logger.flush();
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_init() {
MultiLogger::init().unwrap();
}
}

View File

@@ -92,7 +92,7 @@ impl<'a> Plugin for Intercept<'a> {
.map(|s| { .map(|s| {
let s = s.as_str().to_string(); let s = s.as_str().to_string();
if s == "" { if s.is_empty() {
None None
} else { } else {
Some(s) Some(s)

View File

@@ -30,6 +30,7 @@ pub enum PluginOutput {
Intercept(BytesMut), Intercept(BytesMut),
} }
#[allow(clippy::ptr_arg)]
#[async_trait] #[async_trait]
pub trait Plugin { pub trait Plugin {
// Run before the query is sent to the server. // Run before the query is sent to the server.

View File

@@ -20,7 +20,7 @@ impl<'a> Prewarmer<'a> {
self.server.address(), self.server.address(),
query query
); );
self.server.query(&query).await?; self.server.query(query).await?;
} }
Ok(()) Ok(())

View File

@@ -31,7 +31,7 @@ impl<'a> Plugin for QueryLogger<'a> {
.map(|q| q.to_string()) .map(|q| q.to_string())
.collect::<Vec<String>>() .collect::<Vec<String>>()
.join("; "); .join("; ");
info!("[pool: {}][user: {}] {}", self.user, self.db, query); info!("[pool: {}][user: {}] {}", self.db, self.user, query);
Ok(PluginOutput::Allow) Ok(PluginOutput::Allow)
} }

View File

@@ -30,27 +30,22 @@ impl<'a> Plugin for TableAccess<'a> {
return Ok(PluginOutput::Allow); return Ok(PluginOutput::Allow);
} }
let mut found = None; let control_flow = visit_relations(ast, |relation| {
visit_relations(ast, |relation| {
let relation = relation.to_string(); let relation = relation.to_string();
let parts = relation.split(".").collect::<Vec<&str>>(); let table_name = relation.split('.').last().unwrap().to_string();
let table_name = parts.last().unwrap();
if self.tables.contains(&table_name.to_string()) { if self.tables.contains(&table_name) {
found = Some(table_name.to_string()); ControlFlow::Break(table_name)
ControlFlow::<()>::Break(())
} else { } else {
ControlFlow::<()>::Continue(()) ControlFlow::Continue(())
} }
}); });
if let Some(found) = found { if let ControlFlow::Break(found) = control_flow {
debug!("Blocking access to table \"{}\"", found); debug!("Blocking access to table \"{found}\"");
Ok(PluginOutput::Deny(format!( Ok(PluginOutput::Deny(format!(
"permission for table \"{}\" denied", "permission for table \"{found}\" denied",
found
))) )))
} else { } else {
Ok(PluginOutput::Allow) Ok(PluginOutput::Allow)

View File

@@ -229,20 +229,17 @@ impl ConnectionPool {
let old_pool_ref = get_pool(pool_name, &user.username); let old_pool_ref = get_pool(pool_name, &user.username);
let identifier = PoolIdentifier::new(pool_name, &user.username); let identifier = PoolIdentifier::new(pool_name, &user.username);
match old_pool_ref { if let Some(pool) = old_pool_ref {
Some(pool) => { // If the pool hasn't changed, get existing reference and insert it into the new_pools.
// If the pool hasn't changed, get existing reference and insert it into the new_pools. // We replace all pools at the end, but if the reference is kept, the pool won't get re-created (bb8).
// We replace all pools at the end, but if the reference is kept, the pool won't get re-created (bb8). if pool.config_hash == new_pool_hash_value {
if pool.config_hash == new_pool_hash_value { info!(
info!( "[pool: {}][user: {}] has not changed",
"[pool: {}][user: {}] has not changed", pool_name, user.username
pool_name, user.username );
); new_pools.insert(identifier.clone(), pool.clone());
new_pools.insert(identifier.clone(), pool.clone()); continue;
continue;
}
} }
None => (),
} }
info!( info!(
@@ -628,7 +625,7 @@ impl ConnectionPool {
let mut force_healthcheck = false; let mut force_healthcheck = false;
if self.is_banned(address) { if self.is_banned(address) {
if self.try_unban(&address).await { if self.try_unban(address).await {
force_healthcheck = true; force_healthcheck = true;
} else { } else {
debug!("Address {:?} is banned", address); debug!("Address {:?} is banned", address);
@@ -748,8 +745,8 @@ impl ConnectionPool {
// Don't leave a bad connection in the pool. // Don't leave a bad connection in the pool.
server.mark_bad(); server.mark_bad();
self.ban(&address, BanReason::FailedHealthCheck, Some(client_info)); self.ban(address, BanReason::FailedHealthCheck, Some(client_info));
return false; false
} }
/// Ban an address (i.e. replica). It no longer will serve /// Ban an address (i.e. replica). It no longer will serve
@@ -861,10 +858,10 @@ impl ConnectionPool {
let guard = self.banlist.read(); let guard = self.banlist.read();
for banlist in guard.iter() { for banlist in guard.iter() {
for (address, (reason, timestamp)) in banlist.iter() { for (address, (reason, timestamp)) in banlist.iter() {
bans.push((address.clone(), (reason.clone(), timestamp.clone()))); bans.push((address.clone(), (reason.clone(), *timestamp)));
} }
} }
return bans; bans
} }
/// Get the address from the host url /// Get the address from the host url
@@ -921,7 +918,7 @@ impl ConnectionPool {
} }
let busy = provisioned - idle; let busy = provisioned - idle;
debug!("{:?} has {:?} busy connections", address, busy); debug!("{:?} has {:?} busy connections", address, busy);
return busy; busy
} }
} }

View File

@@ -1,6 +1,6 @@
use hyper::service::{make_service_fn, service_fn}; use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Method, Request, Response, Server, StatusCode}; use hyper::{Body, Method, Request, Response, Server, StatusCode};
use log::{error, info, debug}; use log::{debug, error, info};
use phf::phf_map; use phf::phf_map;
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt; use std::fmt;
@@ -364,7 +364,7 @@ fn push_server_stats(lines: &mut Vec<String>) {
{ {
lines.push(prometheus_metric.to_string()); lines.push(prometheus_metric.to_string());
} else { } else {
warn!("Metric {} not implemented for {}", key, address.name()); debug!("Metric {} not implemented for {}", key, address.name());
} }
} }
} }

View File

@@ -67,6 +67,7 @@ static CUSTOM_SQL_REGEX_SET: OnceCell<RegexSet> = OnceCell::new();
static CUSTOM_SQL_REGEX_LIST: OnceCell<Vec<Regex>> = OnceCell::new(); static CUSTOM_SQL_REGEX_LIST: OnceCell<Vec<Regex>> = OnceCell::new();
/// The query router. /// The query router.
#[derive(Default)]
pub struct QueryRouter { pub struct QueryRouter {
/// Which shard we should be talking to right now. /// Which shard we should be talking to right now.
active_shard: Option<usize>, active_shard: Option<usize>,
@@ -91,7 +92,7 @@ impl QueryRouter {
/// One-time initialization of regexes /// One-time initialization of regexes
/// that parse our custom SQL protocol. /// that parse our custom SQL protocol.
pub fn setup() -> bool { pub fn setup() -> bool {
let set = match RegexSet::new(&CUSTOM_SQL_REGEXES) { let set = match RegexSet::new(CUSTOM_SQL_REGEXES) {
Ok(rgx) => rgx, Ok(rgx) => rgx,
Err(err) => { Err(err) => {
error!("QueryRouter::setup Could not compile regex set: {:?}", err); error!("QueryRouter::setup Could not compile regex set: {:?}", err);
@@ -116,15 +117,8 @@ impl QueryRouter {
/// Create a new instance of the query router. /// Create a new instance of the query router.
/// Each client gets its own. /// Each client gets its own.
pub fn new() -> QueryRouter { pub fn new() -> Self {
QueryRouter { Self::default()
active_shard: None,
active_role: None,
query_parser_enabled: None,
primary_reads_enabled: None,
pool_settings: PoolSettings::default(),
placeholders: Vec::new(),
}
} }
/// Pool settings can change because of a config reload. /// Pool settings can change because of a config reload.
@@ -132,7 +126,7 @@ impl QueryRouter {
self.pool_settings = pool_settings; self.pool_settings = pool_settings;
} }
pub fn pool_settings<'a>(&'a self) -> &'a PoolSettings { pub fn pool_settings(&self) -> &PoolSettings {
&self.pool_settings &self.pool_settings
} }
@@ -143,7 +137,7 @@ impl QueryRouter {
let code = message_cursor.get_u8() as char; let code = message_cursor.get_u8() as char;
// Check for any sharding regex matches in any queries // Check for any sharding regex matches in any queries
match code as char { match code {
// For Parse and Query messages peek to see if they specify a shard_id as a comment early in the statement // For Parse and Query messages peek to see if they specify a shard_id as a comment early in the statement
'P' | 'Q' => { 'P' | 'Q' => {
if self.pool_settings.shard_id_regex.is_some() if self.pool_settings.shard_id_regex.is_some()
@@ -397,14 +391,10 @@ impl QueryRouter {
// or discard shard selection. If they point to the same shard though, // or discard shard selection. If they point to the same shard though,
// we can let them through as-is. // we can let them through as-is.
// This is basically building a database now :) // This is basically building a database now :)
match self.infer_shard(query) { if let Some(shard) = self.infer_shard(query) {
Some(shard) => { self.active_shard = Some(shard);
self.active_shard = Some(shard); debug!("Automatically using shard: {:?}", self.active_shard);
debug!("Automatically using shard: {:?}", self.active_shard); }
}
None => (),
};
} }
None => (), None => (),
@@ -576,8 +566,8 @@ impl QueryRouter {
.automatic_sharding_key .automatic_sharding_key
.as_ref() .as_ref()
.unwrap() .unwrap()
.split(".") .split('.')
.map(|ident| Ident::new(ident)) .map(Ident::new)
.collect::<Vec<Ident>>(); .collect::<Vec<Ident>>();
// Sharding key must be always fully qualified // Sharding key must be always fully qualified
@@ -593,7 +583,7 @@ impl QueryRouter {
Expr::Identifier(ident) => { Expr::Identifier(ident) => {
// Only if we're dealing with only one table // Only if we're dealing with only one table
// and there is no ambiguity // and there is no ambiguity
if &ident.value == &sharding_key[1].value { if ident.value == sharding_key[1].value {
// Sharding key is unique enough, don't worry about // Sharding key is unique enough, don't worry about
// table names. // table names.
if &sharding_key[0].value == "*" { if &sharding_key[0].value == "*" {
@@ -606,13 +596,13 @@ impl QueryRouter {
// SELECT * FROM t WHERE sharding_key = 5 // SELECT * FROM t WHERE sharding_key = 5
// Make sure the table name from the sharding key matches // Make sure the table name from the sharding key matches
// the table name from the query. // the table name from the query.
found = &sharding_key[0].value == &table[0].value; found = sharding_key[0].value == table[0].value;
} else if table.len() == 2 { } else if table.len() == 2 {
// Table name is fully qualified with the schema: e.g. // Table name is fully qualified with the schema: e.g.
// SELECT * FROM public.t WHERE sharding_key = 5 // SELECT * FROM public.t WHERE sharding_key = 5
// Ignore the schema (TODO: at some point, we want schema support) // Ignore the schema (TODO: at some point, we want schema support)
// and use the table name only. // and use the table name only.
found = &sharding_key[0].value == &table[1].value; found = sharding_key[0].value == table[1].value;
} else { } else {
debug!("Got table name with more than two idents, which is not possible"); debug!("Got table name with more than two idents, which is not possible");
} }
@@ -624,8 +614,8 @@ impl QueryRouter {
// The key is fully qualified in the query, // The key is fully qualified in the query,
// it will exist or Postgres will throw an error. // it will exist or Postgres will throw an error.
if idents.len() == 2 { if idents.len() == 2 {
found = &sharding_key[0].value == &idents[0].value found = sharding_key[0].value == idents[0].value
&& &sharding_key[1].value == &idents[1].value; && sharding_key[1].value == idents[1].value;
} }
// TODO: key can have schema as well, e.g. public.data.id (len == 3) // TODO: key can have schema as well, e.g. public.data.id (len == 3)
} }
@@ -657,7 +647,7 @@ impl QueryRouter {
} }
Expr::Value(Value::Placeholder(placeholder)) => { Expr::Value(Value::Placeholder(placeholder)) => {
match placeholder.replace("$", "").parse::<i16>() { match placeholder.replace('$', "").parse::<i16>() {
Ok(placeholder) => result.push(ShardingKey::Placeholder(placeholder)), Ok(placeholder) => result.push(ShardingKey::Placeholder(placeholder)),
Err(_) => { Err(_) => {
debug!( debug!(
@@ -683,12 +673,9 @@ impl QueryRouter {
match &*query.body { match &*query.body {
SetExpr::Query(query) => { SetExpr::Query(query) => {
match self.infer_shard(&*query) { if let Some(shard) = self.infer_shard(query) {
Some(shard) => { shards.insert(shard);
shards.insert(shard); }
}
None => (),
};
} }
// SELECT * FROM ... // SELECT * FROM ...
@@ -698,38 +685,22 @@ impl QueryRouter {
let mut table_names = Vec::new(); let mut table_names = Vec::new();
for table in select.from.iter() { for table in select.from.iter() {
match &table.relation { if let TableFactor::Table { name, .. } = &table.relation {
TableFactor::Table { name, .. } => { table_names.push(name.0.clone());
table_names.push(name.0.clone()); }
}
_ => (),
};
// Get table names from all the joins. // Get table names from all the joins.
for join in table.joins.iter() { for join in table.joins.iter() {
match &join.relation { if let TableFactor::Table { name, .. } = &join.relation {
TableFactor::Table { name, .. } => { table_names.push(name.0.clone());
table_names.push(name.0.clone()); }
}
_ => (),
};
// We can filter results based on join conditions, e.g. // We can filter results based on join conditions, e.g.
// SELECT * FROM t INNER JOIN B ON B.sharding_key = 5; // SELECT * FROM t INNER JOIN B ON B.sharding_key = 5;
match &join.join_operator { if let JoinOperator::Inner(JoinConstraint::On(expr)) = &join.join_operator {
JoinOperator::Inner(inner_join) => match &inner_join { // Parse the selection criteria later.
JoinConstraint::On(expr) => { exprs.push(expr.clone());
// Parse the selection criteria later. }
exprs.push(expr.clone());
}
_ => (),
},
_ => (),
};
} }
} }
@@ -803,16 +774,16 @@ impl QueryRouter {
db: &self.pool_settings.db, db: &self.pool_settings.db,
}; };
let _ = query_logger.run(&self, ast).await; let _ = query_logger.run(self, ast).await;
} }
if let Some(ref intercept) = plugins.intercept { if let Some(ref intercept) = plugins.intercept {
let mut intercept = Intercept { let mut intercept = Intercept {
enabled: intercept.enabled, enabled: intercept.enabled,
config: &intercept, config: intercept,
}; };
let result = intercept.run(&self, ast).await; let result = intercept.run(self, ast).await;
if let Ok(PluginOutput::Intercept(output)) = result { if let Ok(PluginOutput::Intercept(output)) = result {
return Ok(PluginOutput::Intercept(output)); return Ok(PluginOutput::Intercept(output));
@@ -825,7 +796,7 @@ impl QueryRouter {
tables: &table_access.tables, tables: &table_access.tables,
}; };
let result = table_access.run(&self, ast).await; let result = table_access.run(self, ast).await;
if let Ok(PluginOutput::Deny(error)) = result { if let Ok(PluginOutput::Deny(error)) = result {
return Ok(PluginOutput::Deny(error)); return Ok(PluginOutput::Deny(error));
@@ -861,7 +832,7 @@ impl QueryRouter {
/// Should we attempt to parse queries? /// Should we attempt to parse queries?
pub fn query_parser_enabled(&self) -> bool { pub fn query_parser_enabled(&self) -> bool {
let enabled = match self.query_parser_enabled { match self.query_parser_enabled {
None => { None => {
debug!( debug!(
"Using pool settings, query_parser_enabled: {}", "Using pool settings, query_parser_enabled: {}",
@@ -877,9 +848,7 @@ impl QueryRouter {
); );
value value
} }
}; }
enabled
} }
pub fn primary_reads_enabled(&self) -> bool { pub fn primary_reads_enabled(&self) -> bool {
@@ -910,10 +879,14 @@ mod test {
fn test_infer_replica() { fn test_infer_replica() {
QueryRouter::setup(); QueryRouter::setup();
let mut qr = QueryRouter::new(); let mut qr = QueryRouter::new();
assert!(qr.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'")) != None); assert!(qr
.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'"))
.is_some());
assert!(qr.query_parser_enabled()); assert!(qr.query_parser_enabled());
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None); assert!(qr
.try_execute_command(&simple_query("SET PRIMARY READS TO off"))
.is_some());
let queries = vec![ let queries = vec![
simple_query("SELECT * FROM items WHERE id = 5"), simple_query("SELECT * FROM items WHERE id = 5"),
@@ -954,7 +927,9 @@ mod test {
QueryRouter::setup(); QueryRouter::setup();
let mut qr = QueryRouter::new(); let mut qr = QueryRouter::new();
let query = simple_query("SELECT * FROM items WHERE id = 5"); let query = simple_query("SELECT * FROM items WHERE id = 5");
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO on")) != None); assert!(qr
.try_execute_command(&simple_query("SET PRIMARY READS TO on"))
.is_some());
assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok()); assert!(qr.infer(&QueryRouter::parse(&query).unwrap()).is_ok());
assert_eq!(qr.role(), None); assert_eq!(qr.role(), None);
@@ -965,7 +940,9 @@ mod test {
QueryRouter::setup(); QueryRouter::setup();
let mut qr = QueryRouter::new(); let mut qr = QueryRouter::new();
qr.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'")); qr.try_execute_command(&simple_query("SET SERVER ROLE TO 'auto'"));
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None); assert!(qr
.try_execute_command(&simple_query("SET PRIMARY READS TO off"))
.is_some());
let prepared_stmt = BytesMut::from( let prepared_stmt = BytesMut::from(
&b"WITH t AS (SELECT * FROM items WHERE name = $1) SELECT * FROM t WHERE id = $2\0"[..], &b"WITH t AS (SELECT * FROM items WHERE name = $1) SELECT * FROM t WHERE id = $2\0"[..],
@@ -1133,9 +1110,11 @@ mod test {
QueryRouter::setup(); QueryRouter::setup();
let mut qr = QueryRouter::new(); let mut qr = QueryRouter::new();
let query = simple_query("SET SERVER ROLE TO 'auto'"); let query = simple_query("SET SERVER ROLE TO 'auto'");
assert!(qr.try_execute_command(&simple_query("SET PRIMARY READS TO off")) != None); assert!(qr
.try_execute_command(&simple_query("SET PRIMARY READS TO off"))
.is_some());
assert!(qr.try_execute_command(&query) != None); assert!(qr.try_execute_command(&query).is_some());
assert!(qr.query_parser_enabled()); assert!(qr.query_parser_enabled());
assert_eq!(qr.role(), None); assert_eq!(qr.role(), None);
@@ -1149,7 +1128,7 @@ mod test {
assert!(qr.query_parser_enabled()); assert!(qr.query_parser_enabled());
let query = simple_query("SET SERVER ROLE TO 'default'"); let query = simple_query("SET SERVER ROLE TO 'default'");
assert!(qr.try_execute_command(&query) != None); assert!(qr.try_execute_command(&query).is_some());
assert!(!qr.query_parser_enabled()); assert!(!qr.query_parser_enabled());
} }
@@ -1194,11 +1173,11 @@ mod test {
assert!(!qr.primary_reads_enabled()); assert!(!qr.primary_reads_enabled());
let q1 = simple_query("SET SERVER ROLE TO 'primary'"); let q1 = simple_query("SET SERVER ROLE TO 'primary'");
assert!(qr.try_execute_command(&q1) != None); assert!(qr.try_execute_command(&q1).is_some());
assert_eq!(qr.active_role.unwrap(), Role::Primary); assert_eq!(qr.active_role.unwrap(), Role::Primary);
let q2 = simple_query("SET SERVER ROLE TO 'default'"); let q2 = simple_query("SET SERVER ROLE TO 'default'");
assert!(qr.try_execute_command(&q2) != None); assert!(qr.try_execute_command(&q2).is_some());
assert_eq!(qr.active_role.unwrap(), pool_settings.default_role); assert_eq!(qr.active_role.unwrap(), pool_settings.default_role);
} }
@@ -1263,17 +1242,17 @@ mod test {
// Make sure setting it works // Make sure setting it works
let q1 = simple_query("/* shard_id: 1 */ select 1 from foo;"); let q1 = simple_query("/* shard_id: 1 */ select 1 from foo;");
assert!(qr.try_execute_command(&q1) == None); assert!(qr.try_execute_command(&q1).is_none());
assert_eq!(qr.active_shard, Some(1)); assert_eq!(qr.active_shard, Some(1));
// And make sure changing it works // And make sure changing it works
let q2 = simple_query("/* shard_id: 0 */ select 1 from foo;"); let q2 = simple_query("/* shard_id: 0 */ select 1 from foo;");
assert!(qr.try_execute_command(&q2) == None); assert!(qr.try_execute_command(&q2).is_none());
assert_eq!(qr.active_shard, Some(0)); assert_eq!(qr.active_shard, Some(0));
// Validate setting by shard with expected shard copied from sharding.rs tests // Validate setting by shard with expected shard copied from sharding.rs tests
let q2 = simple_query("/* sharding_key: 6 */ select 1 from foo;"); let q2 = simple_query("/* sharding_key: 6 */ select 1 from foo;");
assert!(qr.try_execute_command(&q2) == None); assert!(qr.try_execute_command(&q2).is_none());
assert_eq!(qr.active_shard, Some(2)); assert_eq!(qr.active_shard, Some(2));
} }
@@ -1411,9 +1390,11 @@ mod test {
}; };
QueryRouter::setup(); QueryRouter::setup();
let mut pool_settings = PoolSettings::default(); let pool_settings = PoolSettings {
pool_settings.query_parser_enabled = true; query_parser_enabled: true,
pool_settings.plugins = Some(plugins); plugins: Some(plugins),
..Default::default()
};
let mut qr = QueryRouter::new(); let mut qr = QueryRouter::new();
qr.update_pool_settings(pool_settings); qr.update_pool_settings(pool_settings);

View File

@@ -79,12 +79,12 @@ impl ScramSha256 {
let server_message = Message::parse(message)?; let server_message = Message::parse(message)?;
if !server_message.nonce.starts_with(&self.nonce) { if !server_message.nonce.starts_with(&self.nonce) {
return Err(Error::ProtocolSyncError(format!("SCRAM"))); return Err(Error::ProtocolSyncError("SCRAM".to_string()));
} }
let salt = match general_purpose::STANDARD.decode(&server_message.salt) { let salt = match general_purpose::STANDARD.decode(&server_message.salt) {
Ok(salt) => salt, Ok(salt) => salt,
Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))), Err(_) => return Err(Error::ProtocolSyncError("SCRAM".to_string())),
}; };
let salted_password = Self::hi( let salted_password = Self::hi(
@@ -166,9 +166,9 @@ impl ScramSha256 {
pub fn finish(&mut self, message: &BytesMut) -> Result<(), Error> { pub fn finish(&mut self, message: &BytesMut) -> Result<(), Error> {
let final_message = FinalMessage::parse(message)?; let final_message = FinalMessage::parse(message)?;
let verifier = match general_purpose::STANDARD.decode(&final_message.value) { let verifier = match general_purpose::STANDARD.decode(final_message.value) {
Ok(verifier) => verifier, Ok(verifier) => verifier,
Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))), Err(_) => return Err(Error::ProtocolSyncError("SCRAM".to_string())),
}; };
let mut hmac = match Hmac::<Sha256>::new_from_slice(&self.salted_password) { let mut hmac = match Hmac::<Sha256>::new_from_slice(&self.salted_password) {
@@ -230,14 +230,14 @@ impl Message {
.collect::<Vec<String>>(); .collect::<Vec<String>>();
if parts.len() != 3 { if parts.len() != 3 {
return Err(Error::ProtocolSyncError(format!("SCRAM"))); return Err(Error::ProtocolSyncError("SCRAM".to_string()));
} }
let nonce = str::replace(&parts[0], "r=", ""); let nonce = str::replace(&parts[0], "r=", "");
let salt = str::replace(&parts[1], "s=", ""); let salt = str::replace(&parts[1], "s=", "");
let iterations = match str::replace(&parts[2], "i=", "").parse::<u32>() { let iterations = match str::replace(&parts[2], "i=", "").parse::<u32>() {
Ok(iterations) => iterations, Ok(iterations) => iterations,
Err(_) => return Err(Error::ProtocolSyncError(format!("SCRAM"))), Err(_) => return Err(Error::ProtocolSyncError("SCRAM".to_string())),
}; };
Ok(Message { Ok(Message {
@@ -257,7 +257,7 @@ impl FinalMessage {
/// Parse the server final validation message. /// Parse the server final validation message.
pub fn parse(message: &BytesMut) -> Result<FinalMessage, Error> { pub fn parse(message: &BytesMut) -> Result<FinalMessage, Error> {
if !message.starts_with(b"v=") || message.len() < 4 { if !message.starts_with(b"v=") || message.len() < 4 {
return Err(Error::ProtocolSyncError(format!("SCRAM"))); return Err(Error::ProtocolSyncError("SCRAM".to_string()));
} }
Ok(FinalMessage { Ok(FinalMessage {

View File

@@ -15,7 +15,7 @@ use tokio::net::TcpStream;
use tokio_rustls::rustls::{OwnedTrustAnchor, RootCertStore}; use tokio_rustls::rustls::{OwnedTrustAnchor, RootCertStore};
use tokio_rustls::{client::TlsStream, TlsConnector}; use tokio_rustls::{client::TlsStream, TlsConnector};
use crate::config::{get_config, Address, User}; use crate::config::{get_config, get_prepared_statements_cache_size, Address, User};
use crate::constants::*; use crate::constants::*;
use crate::dns_cache::{AddrSet, CACHED_RESOLVER}; use crate::dns_cache::{AddrSet, CACHED_RESOLVER};
use crate::errors::{Error, ServerIdentifier}; use crate::errors::{Error, ServerIdentifier};
@@ -170,6 +170,9 @@ pub struct Server {
/// Is there more data for the client to read. /// Is there more data for the client to read.
data_available: bool, data_available: bool,
/// Is the server in copy-in or copy-out modes
in_copy_mode: bool,
/// Is the server broken? We'll remote it from the pool if so. /// Is the server broken? We'll remote it from the pool if so.
bad: bool, bad: bool,
@@ -313,10 +316,7 @@ impl Server {
// Something else? // Something else?
m => { m => {
return Err(Error::SocketError(format!( return Err(Error::SocketError(format!("Unknown message: {}", { m })));
"Unknown message: {}",
m as char
)));
} }
} }
} else { } else {
@@ -334,27 +334,18 @@ impl Server {
None => &user.username, None => &user.username,
}; };
let password = match user.server_password { let password = user.server_password.as_ref();
Some(ref server_password) => Some(server_password),
None => match user.password {
Some(ref password) => Some(password),
None => None,
},
};
startup(&mut stream, username, database).await?; startup(&mut stream, username, database).await?;
let mut server_info = BytesMut::new(); let mut server_info = BytesMut::new();
let mut process_id: i32 = 0; let mut process_id: i32 = 0;
let mut secret_key: i32 = 0; let mut secret_key: i32 = 0;
let server_identifier = ServerIdentifier::new(username, &database); let server_identifier = ServerIdentifier::new(username, database);
// We'll be handling multiple packets, but they will all be structured the same. // We'll be handling multiple packets, but they will all be structured the same.
// We'll loop here until this exchange is complete. // We'll loop here until this exchange is complete.
let mut scram: Option<ScramSha256> = match password { let mut scram: Option<ScramSha256> = password.map(|password| ScramSha256::new(password));
Some(password) => Some(ScramSha256::new(password)),
None => None,
};
loop { loop {
let code = match stream.read_u8().await { let code = match stream.read_u8().await {
@@ -677,6 +668,7 @@ impl Server {
process_id, process_id,
secret_key, secret_key,
in_transaction: false, in_transaction: false,
in_copy_mode: false,
data_available: false, data_available: false,
bad: false, bad: false,
cleanup_state: CleanupState::new(), cleanup_state: CleanupState::new(),
@@ -749,7 +741,7 @@ impl Server {
self.mirror_send(messages); self.mirror_send(messages);
self.stats().data_sent(messages.len()); self.stats().data_sent(messages.len());
match write_all_flush(&mut self.stream, &messages).await { match write_all_flush(&mut self.stream, messages).await {
Ok(_) => { Ok(_) => {
// Successfully sent to server // Successfully sent to server
self.last_activity = SystemTime::now(); self.last_activity = SystemTime::now();
@@ -828,8 +820,19 @@ impl Server {
break; break;
} }
// ErrorResponse
'E' => {
if self.in_copy_mode {
self.in_copy_mode = false;
}
}
// CommandComplete // CommandComplete
'C' => { 'C' => {
if self.in_copy_mode {
self.in_copy_mode = false;
}
let mut command_tag = String::new(); let mut command_tag = String::new();
match message.reader().read_to_string(&mut command_tag) { match message.reader().read_to_string(&mut command_tag) {
Ok(_) => { Ok(_) => {
@@ -873,10 +876,14 @@ impl Server {
} }
// CopyInResponse: copy is starting from client to server. // CopyInResponse: copy is starting from client to server.
'G' => break, 'G' => {
self.in_copy_mode = true;
break;
}
// CopyOutResponse: copy is starting from the server to the client. // CopyOutResponse: copy is starting from the server to the client.
'H' => { 'H' => {
self.in_copy_mode = true;
self.data_available = true; self.data_available = true;
break; break;
} }
@@ -914,12 +921,16 @@ impl Server {
Ok(bytes) Ok(bytes)
} }
/// Add the prepared statement to being tracked by this server.
/// The client is processing data that will create a prepared statement on this server.
pub fn will_prepare(&mut self, name: &str) { pub fn will_prepare(&mut self, name: &str) {
debug!("Will prepare `{}`", name); debug!("Will prepare `{}`", name);
self.prepared_statements.insert(name.to_string()); self.prepared_statements.insert(name.to_string());
self.stats.prepared_cache_add();
} }
/// Check if we should prepare a statement on the server.
pub fn should_prepare(&self, name: &str) -> bool { pub fn should_prepare(&self, name: &str) -> bool {
let should_prepare = !self.prepared_statements.contains(name); let should_prepare = !self.prepared_statements.contains(name);
@@ -934,6 +945,7 @@ impl Server {
should_prepare should_prepare
} }
/// Create a prepared statement on the server.
pub async fn prepare(&mut self, parse: &Parse) -> Result<(), Error> { pub async fn prepare(&mut self, parse: &Parse) -> Result<(), Error> {
debug!("Preparing `{}`", parse.name); debug!("Preparing `{}`", parse.name);
@@ -942,15 +954,82 @@ impl Server {
self.send(&flush()).await?; self.send(&flush()).await?;
// Read and discard ParseComplete (B) // Read and discard ParseComplete (B)
let _ = read_message(&mut self.stream).await?; match read_message(&mut self.stream).await {
Ok(_) => (),
Err(err) => {
self.bad = true;
return Err(err);
}
}
self.prepared_statements.insert(parse.name.to_string()); self.prepared_statements.insert(parse.name.to_string());
self.stats.prepared_cache_add();
debug!("Prepared `{}`", parse.name); debug!("Prepared `{}`", parse.name);
Ok(()) Ok(())
} }
/// Maintain adequate cache size on the server.
pub async fn maintain_cache(&mut self) -> Result<(), Error> {
debug!("Cache maintenance run");
let max_cache_size = get_prepared_statements_cache_size();
let mut names = Vec::new();
while self.prepared_statements.len() >= max_cache_size {
// The prepared statmeents are alphanumerically sorted by the BTree.
// FIFO.
if let Some(name) = self.prepared_statements.pop_last() {
names.push(name);
}
}
self.deallocate(names).await?;
Ok(())
}
/// Remove the prepared statement from being tracked by this server.
/// The client is processing data that will cause the server to close the prepared statement.
pub fn will_close(&mut self, name: &str) {
debug!("Will close `{}`", name);
self.prepared_statements.remove(name);
}
/// Close a prepared statement on the server.
pub async fn deallocate(&mut self, names: Vec<String>) -> Result<(), Error> {
for name in &names {
debug!("Deallocating prepared statement `{}`", name);
let close = Close::new(name);
let bytes: BytesMut = close.try_into()?;
self.send(&bytes).await?;
}
self.send(&flush()).await?;
// Read and discard CloseComplete (3)
for name in &names {
match read_message(&mut self.stream).await {
Ok(_) => {
self.prepared_statements.remove(name);
self.stats.prepared_cache_remove();
debug!("Closed `{}`", name);
}
Err(err) => {
self.bad = true;
return Err(err);
}
};
}
Ok(())
}
/// If the server is still inside a transaction. /// If the server is still inside a transaction.
/// If the client disconnects while the server is in a transaction, we will clean it up. /// If the client disconnects while the server is in a transaction, we will clean it up.
pub fn in_transaction(&self) -> bool { pub fn in_transaction(&self) -> bool {
@@ -958,6 +1037,10 @@ impl Server {
self.in_transaction self.in_transaction
} }
pub fn in_copy_mode(&self) -> bool {
self.in_copy_mode
}
/// We don't buffer all of server responses, e.g. COPY OUT produces too much data. /// We don't buffer all of server responses, e.g. COPY OUT produces too much data.
/// The client is responsible to call `self.recv()` while this method returns true. /// The client is responsible to call `self.recv()` while this method returns true.
pub fn is_data_available(&self) -> bool { pub fn is_data_available(&self) -> bool {
@@ -1057,6 +1140,10 @@ impl Server {
self.cleanup_state.reset(); self.cleanup_state.reset();
} }
if self.in_copy_mode() {
warn!("Server returned while still in copy-mode");
}
Ok(()) Ok(())
} }
@@ -1100,16 +1187,14 @@ impl Server {
} }
pub fn mirror_send(&mut self, bytes: &BytesMut) { pub fn mirror_send(&mut self, bytes: &BytesMut) {
match self.mirror_manager.as_mut() { if let Some(manager) = self.mirror_manager.as_mut() {
Some(manager) => manager.send(bytes), manager.send(bytes);
None => (),
} }
} }
pub fn mirror_disconnect(&mut self) { pub fn mirror_disconnect(&mut self) {
match self.mirror_manager.as_mut() { if let Some(manager) = self.mirror_manager.as_mut() {
Some(manager) => manager.disconnect(), manager.disconnect();
None => (),
} }
} }
@@ -1137,7 +1222,7 @@ impl Server {
server.send(&simple_query(query)).await?; server.send(&simple_query(query)).await?;
let mut message = server.recv().await?; let mut message = server.recv().await?;
Ok(parse_query_message(&mut message).await?) parse_query_message(&mut message).await
} }
} }

View File

@@ -64,7 +64,7 @@ impl Sharder {
fn sha1(&self, key: i64) -> usize { fn sha1(&self, key: i64) -> usize {
let mut hasher = Sha1::new(); let mut hasher = Sha1::new();
hasher.update(&key.to_string().as_bytes()); hasher.update(key.to_string().as_bytes());
let result = hasher.finalize(); let result = hasher.finalize();

View File

@@ -77,13 +77,12 @@ impl Reporter {
/// The statistics collector which used for calculating averages /// The statistics collector which used for calculating averages
/// There is only one collector (kind of like a singleton) /// There is only one collector (kind of like a singleton)
/// it updates averages every 15 seconds. /// it updates averages every 15 seconds.
#[derive(Default)] pub struct Collector;
pub struct Collector {}
impl Collector { impl Collector {
/// The statistics collection handler. It will collect statistics /// The statistics collection handler. It will collect statistics
/// for `address_id`s starting at 0 up to `addresses`. /// for `address_id`s starting at 0 up to `addresses`.
pub async fn collect(&mut self) { pub fn collect() {
info!("Events reporter started"); info!("Events reporter started");
tokio::task::spawn(async move { tokio::task::spawn(async move {

View File

@@ -86,11 +86,11 @@ impl PoolStats {
} }
} }
return map; map
} }
pub fn generate_header() -> Vec<(&'static str, DataType)> { pub fn generate_header() -> Vec<(&'static str, DataType)> {
return vec![ vec![
("database", DataType::Text), ("database", DataType::Text),
("user", DataType::Text), ("user", DataType::Text),
("pool_mode", DataType::Text), ("pool_mode", DataType::Text),
@@ -105,11 +105,11 @@ impl PoolStats {
("sv_login", DataType::Numeric), ("sv_login", DataType::Numeric),
("maxwait", DataType::Numeric), ("maxwait", DataType::Numeric),
("maxwait_us", DataType::Numeric), ("maxwait_us", DataType::Numeric),
]; ]
} }
pub fn generate_row(&self) -> Vec<String> { pub fn generate_row(&self) -> Vec<String> {
return vec![ vec![
self.identifier.db.clone(), self.identifier.db.clone(),
self.identifier.user.clone(), self.identifier.user.clone(),
self.mode.to_string(), self.mode.to_string(),
@@ -124,7 +124,7 @@ impl PoolStats {
self.sv_login.to_string(), self.sv_login.to_string(),
(self.maxwait / 1_000_000).to_string(), (self.maxwait / 1_000_000).to_string(),
(self.maxwait % 1_000_000).to_string(), (self.maxwait % 1_000_000).to_string(),
]; ]
} }
} }

View File

@@ -49,6 +49,7 @@ pub struct ServerStats {
pub error_count: Arc<AtomicU64>, pub error_count: Arc<AtomicU64>,
pub prepared_hit_count: Arc<AtomicU64>, pub prepared_hit_count: Arc<AtomicU64>,
pub prepared_miss_count: Arc<AtomicU64>, pub prepared_miss_count: Arc<AtomicU64>,
pub prepared_cache_size: Arc<AtomicU64>,
} }
impl Default for ServerStats { impl Default for ServerStats {
@@ -67,6 +68,7 @@ impl Default for ServerStats {
reporter: get_reporter(), reporter: get_reporter(),
prepared_hit_count: Arc::new(AtomicU64::new(0)), prepared_hit_count: Arc::new(AtomicU64::new(0)),
prepared_miss_count: Arc::new(AtomicU64::new(0)), prepared_miss_count: Arc::new(AtomicU64::new(0)),
prepared_cache_size: Arc::new(AtomicU64::new(0)),
} }
} }
} }
@@ -213,4 +215,12 @@ impl ServerStats {
pub fn prepared_cache_miss(&self) { pub fn prepared_cache_miss(&self) {
self.prepared_miss_count.fetch_add(1, Ordering::Relaxed); self.prepared_miss_count.fetch_add(1, Ordering::Relaxed);
} }
pub fn prepared_cache_add(&self) {
self.prepared_cache_size.fetch_add(1, Ordering::Relaxed);
}
pub fn prepared_cache_remove(&self) {
self.prepared_cache_size.fetch_sub(1, Ordering::Relaxed);
}
} }

View File

@@ -44,25 +44,17 @@ impl Tls {
pub fn new() -> Result<Self, Error> { pub fn new() -> Result<Self, Error> {
let config = get_config(); let config = get_config();
let certs = match load_certs(Path::new(&config.general.tls_certificate.unwrap())) { let certs = load_certs(Path::new(&config.general.tls_certificate.unwrap()))
Ok(certs) => certs, .map_err(|_| Error::TlsError)?;
Err(_) => return Err(Error::TlsError), let key_der = load_keys(Path::new(&config.general.tls_private_key.unwrap()))
}; .map_err(|_| Error::TlsError)?
.remove(0);
let mut keys = match load_keys(Path::new(&config.general.tls_private_key.unwrap())) { let config = rustls::ServerConfig::builder()
Ok(keys) => keys,
Err(_) => return Err(Error::TlsError),
};
let config = match rustls::ServerConfig::builder()
.with_safe_defaults() .with_safe_defaults()
.with_no_client_auth() .with_no_client_auth()
.with_single_cert(certs, keys.remove(0)) .with_single_cert(certs, key_der)
.map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidInput, err)) .map_err(|_| Error::TlsError)?;
{
Ok(c) => c,
Err(_) => return Err(Error::TlsError),
};
Ok(Tls { Ok(Tls {
acceptor: TlsAcceptor::from(Arc::new(config)), acceptor: TlsAcceptor::from(Arc::new(config)),

102
tests/ruby/copy_spec.rb Normal file
View File

@@ -0,0 +1,102 @@
# frozen_string_literal: true
require_relative 'spec_helper'
describe "COPY Handling" do
let(:processes) { Helpers::Pgcat.single_instance_setup("sharded_db", 5) }
before do
new_configs = processes.pgcat.current_config
# Allow connections in the pool to expire faster
new_configs["general"]["idle_timeout"] = 5
processes.pgcat.update_config(new_configs)
# We need to kill the old process that was using the default configs
processes.pgcat.stop
processes.pgcat.start
processes.pgcat.wait_until_ready
end
before do
processes.all_databases.first.with_connection do |conn|
conn.async_exec "CREATE TABLE copy_test_table (a TEXT,b TEXT,c TEXT,d TEXT)"
end
end
after do
processes.all_databases.first.with_connection do |conn|
conn.async_exec "DROP TABLE copy_test_table;"
end
end
after do
processes.all_databases.map(&:reset)
processes.pgcat.shutdown
end
describe "COPY FROM" do
context "within transaction" do
it "finishes within alloted time" do
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
Timeout.timeout(3) do
conn.async_exec("BEGIN")
conn.copy_data "COPY copy_test_table FROM STDIN CSV" do
sleep 0.5
conn.put_copy_data "some,data,to,copy\n"
conn.put_copy_data "more,data,to,copy\n"
end
conn.async_exec("COMMIT")
end
res = conn.async_exec("SELECT * FROM copy_test_table").to_a
expect(res).to eq([
{"a"=>"some", "b"=>"data", "c"=>"to", "d"=>"copy"},
{"a"=>"more", "b"=>"data", "c"=>"to", "d"=>"copy"}
])
end
end
context "outside transaction" do
it "finishes within alloted time" do
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
Timeout.timeout(3) do
conn.copy_data "COPY copy_test_table FROM STDIN CSV" do
sleep 0.5
conn.put_copy_data "some,data,to,copy\n"
conn.put_copy_data "more,data,to,copy\n"
end
end
res = conn.async_exec("SELECT * FROM copy_test_table").to_a
expect(res).to eq([
{"a"=>"some", "b"=>"data", "c"=>"to", "d"=>"copy"},
{"a"=>"more", "b"=>"data", "c"=>"to", "d"=>"copy"}
])
end
end
end
describe "COPY TO" do
before do
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
conn.async_exec("BEGIN")
conn.copy_data "COPY copy_test_table FROM STDIN CSV" do
conn.put_copy_data "some,data,to,copy\n"
conn.put_copy_data "more,data,to,copy\n"
end
conn.async_exec("COMMIT")
conn.close
end
it "works" do
res = []
conn = PG.connect(processes.pgcat.connection_string("sharded_db", "sharding_user"))
conn.copy_data "COPY copy_test_table TO STDOUT CSV" do
while row=conn.get_copy_data
res << row
end
end
expect(res).to eq(["some,data,to,copy\n", "more,data,to,copy\n"])
end
end
end

View File

@@ -0,0 +1,29 @@
require_relative 'spec_helper'
describe 'Prepared statements' do
let(:processes) { Helpers::Pgcat.three_shard_setup('sharded_db', 5) }
context 'enabled' do
it 'will work over the same connection' do
conn = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
10.times do |i|
statement_name = "statement_#{i}"
conn.prepare(statement_name, 'SELECT $1::int')
conn.exec_prepared(statement_name, [1])
conn.describe_prepared(statement_name)
end
end
it 'will work with new connections' do
10.times do
conn = PG.connect(processes.pgcat.connection_string('sharded_db', 'sharding_user'))
statement_name = 'statement1'
conn.prepare('statement1', 'SELECT $1::int')
conn.exec_prepared('statement1', [1])
conn.describe_prepared('statement1')
end
end
end
end