mirror of
https://github.com/postgresml/pgcat.git
synced 2026-03-23 09:26:30 +00:00
Compare commits
11 Commits
query-rout
...
sven_md5_a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5f745859c0 | ||
|
|
1e8fa110ae | ||
|
|
d4186b7815 | ||
|
|
aaeef69d59 | ||
|
|
b21e0f4a7e | ||
|
|
eb1473060e | ||
|
|
26f75f8d5d | ||
|
|
99d65fc475 | ||
|
|
206fdc9769 | ||
|
|
f74101cdfe | ||
|
|
8e0682482d |
@@ -57,6 +57,17 @@ cd tests/ruby && \
|
||||
ruby tests.rb && \
|
||||
cd ../..
|
||||
|
||||
# Admin tests
|
||||
psql -e -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW STATS' > /dev/null
|
||||
psql -h 127.0.0.1 -p 6432 -d pgbouncer -c 'RELOAD' > /dev/null
|
||||
psql -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW CONFIG' > /dev/null
|
||||
psql -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW DATABASES' > /dev/null
|
||||
psql -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW LISTS' > /dev/null
|
||||
psql -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW POOLS' > /dev/null
|
||||
psql -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW VERSION' > /dev/null
|
||||
psql -h 127.0.0.1 -p 6432 -d pgbouncer -c "SET client_encoding TO 'utf8'" > /dev/null # will ignore
|
||||
(! psql -e -h 127.0.0.1 -p 6432 -d random_db -c 'SHOW STATS' > /dev/null)
|
||||
|
||||
# Start PgCat in debug to demonstrate failover better
|
||||
start_pgcat "debug"
|
||||
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -1,2 +1,5 @@
|
||||
/target
|
||||
*.deb
|
||||
.idea/*
|
||||
tests/ruby/.bundle/*
|
||||
tests/ruby/vendor/*
|
||||
24
Cargo.lock
generated
24
Cargo.lock
generated
@@ -220,6 +220,12 @@ dependencies = [
|
||||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itoa"
|
||||
version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1aab8fc367588b89dcee83ab0fd66b72b50b72fa1904d7095045ace2b0c81c35"
|
||||
|
||||
[[package]]
|
||||
name = "libc"
|
||||
version = "0.2.117"
|
||||
@@ -369,6 +375,7 @@ dependencies = [
|
||||
"regex",
|
||||
"serde",
|
||||
"serde_derive",
|
||||
"serde_json",
|
||||
"sha-1",
|
||||
"sqlparser",
|
||||
"statsd",
|
||||
@@ -478,6 +485,12 @@ version = "0.6.25"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b"
|
||||
|
||||
[[package]]
|
||||
name = "ryu"
|
||||
version = "1.0.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "73b4b750c782965c211b42f022f59af1fbceabdd026623714f104152f1ec149f"
|
||||
|
||||
[[package]]
|
||||
name = "scopeguard"
|
||||
version = "1.1.0"
|
||||
@@ -501,6 +514,17 @@ dependencies = [
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_json"
|
||||
version = "1.0.79"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8e8d9fa5c3b304765ce1fd9c4c8a3de2c8db365a5b91be52f186efc675681d95"
|
||||
dependencies = [
|
||||
"itoa",
|
||||
"ryu",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sha-1"
|
||||
version = "0.10.0"
|
||||
|
||||
@@ -17,6 +17,7 @@ sha-1 = "0.10"
|
||||
toml = "0.5"
|
||||
serde = "1"
|
||||
serde_derive = "1"
|
||||
serde_json = "1"
|
||||
regex = "1"
|
||||
num_cpus = "1"
|
||||
once_cell = "1"
|
||||
|
||||
25
README.md
25
README.md
@@ -133,6 +133,31 @@ All servers are checked with a `SELECT 1` query before being given to a client.
|
||||
|
||||
The ban time can be changed with `ban_time`. The default is 60 seconds.
|
||||
|
||||
Failover behavior can get pretty interesting (read complex) when multiple configurations and factors are involved. The table below will try to explain what PgCat does in each scenario:
|
||||
|
||||
| **Query** | **`SET SERVER ROLE TO`** | **`query_parser_enabled`** | **`primary_reads_enabled`** | **Target state** | **Outcome** |
|
||||
|---------------------------|--------------------------|----------------------------|-----------------------------|------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| Read query, i.e. `SELECT` | unset (any) | false | false | up | Query is routed to the first instance in the round-robin loop. |
|
||||
| Read query | unset (any) | true | false | up | Query is routed to the first replica instance in the round-robin loop. |
|
||||
| Read query | unset (any) | true | true | up | Query is routed to the first instance in the round-robin loop. |
|
||||
| Read query | replica | false | false | up | Query is routed to the first replica instance in the round-robin loop. |
|
||||
| Read query | primary | false | false | up | Query is routed to the primary. |
|
||||
| Read query | unset (any) | false | false | down | First instance is banned for reads. Next target in the round-robin loop is attempted. |
|
||||
| Read query | unset (any) | true | false | down | First replica instance is banned. Next replica instance is attempted in the round-robin loop. |
|
||||
| Read query | unset (any) | true | true | down | First instance (even if primary) is banned for reads. Next instance is attempted in the round-robin loop. |
|
||||
| Read query | replica | false | false | down | First replica instance is banned. Next replica instance is attempted in the round-robin loop. |
|
||||
| Read query | primary | false | false | down | The query is attempted against the primary and fails. The client receives an error. |
|
||||
| | | | | | |
|
||||
| Write query e.g. `INSERT` | unset (any) | false | false | up | The query is attempted against the first available instance in the round-robin loop. If the instance is a replica, the query fails and the client receives an error. |
|
||||
| Write query | unset (any) | true | false | up | The query is routed to the primary. |
|
||||
| Write query | unset (any) | true | true | up | The query is routed to the primary. |
|
||||
| Write query | primary | false | false | up | The query is routed to the primary. |
|
||||
| Write query | replica | false | false | up | The query is routed to the replica and fails. The client receives an error. |
|
||||
| Write query | unset (any) | true | false | down | The query is routed to the primary and fails. The client receives an error. |
|
||||
| Write query | unset (any) | true | true | down | The query is routed to the primary and fails. The client receives an error. |
|
||||
| Write query | primary | false | false | down | The query is routed to the primary and fails. The client receives an error. |
|
||||
| | | | | | |
|
||||
|
||||
### Sharding
|
||||
We use the `PARTITION BY HASH` hashing function, the same as used by Postgres for declarative partitioning. This allows to shard the database using Postgres partitions and place the partitions on different servers (shards). Both read and write queries can be routed to the shards using this pooler.
|
||||
|
||||
|
||||
12
pgcat.toml
12
pgcat.toml
@@ -48,8 +48,8 @@ password = "sharding_user"
|
||||
|
||||
# [ host, port, role ]
|
||||
servers = [
|
||||
[ "127.0.0.1", 5432, "primary" ],
|
||||
[ "localhost", 5432, "replica" ],
|
||||
["127.0.0.1", 5432, "primary"],
|
||||
["localhost", 5432, "replica"],
|
||||
# [ "127.0.1.1", 5432, "replica" ],
|
||||
]
|
||||
# Database name (e.g. "postgres")
|
||||
@@ -58,8 +58,8 @@ database = "shard0"
|
||||
[shards.1]
|
||||
# [ host, port, role ]
|
||||
servers = [
|
||||
[ "127.0.0.1", 5432, "primary" ],
|
||||
[ "localhost", 5432, "replica" ],
|
||||
["127.0.0.1", 5432, "primary"],
|
||||
["localhost", 5432, "replica"],
|
||||
# [ "127.0.1.1", 5432, "replica" ],
|
||||
]
|
||||
database = "shard1"
|
||||
@@ -67,8 +67,8 @@ database = "shard1"
|
||||
[shards.2]
|
||||
# [ host, port, role ]
|
||||
servers = [
|
||||
[ "127.0.0.1", 5432, "primary" ],
|
||||
[ "localhost", 5432, "replica" ],
|
||||
["127.0.0.1", 5432, "primary"],
|
||||
["localhost", 5432, "replica"],
|
||||
# [ "127.0.1.1", 5432, "replica" ],
|
||||
]
|
||||
database = "shard2"
|
||||
|
||||
351
src/admin.rs
Normal file
351
src/admin.rs
Normal file
@@ -0,0 +1,351 @@
|
||||
use bytes::{Buf, BufMut, BytesMut};
|
||||
use log::{info, trace};
|
||||
use tokio::net::tcp::OwnedWriteHalf;
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::config::{get_config, parse};
|
||||
use crate::errors::Error;
|
||||
use crate::messages::*;
|
||||
use crate::pool::ConnectionPool;
|
||||
use crate::stats::get_stats;
|
||||
|
||||
/// Handle admin client
|
||||
pub async fn handle_admin(
|
||||
stream: &mut OwnedWriteHalf,
|
||||
mut query: BytesMut,
|
||||
pool: ConnectionPool,
|
||||
) -> Result<(), Error> {
|
||||
let code = query.get_u8() as char;
|
||||
|
||||
if code != 'Q' {
|
||||
return Err(Error::ProtocolSyncError);
|
||||
}
|
||||
|
||||
let len = query.get_i32() as usize;
|
||||
let query = String::from_utf8_lossy(&query[..len - 5])
|
||||
.to_string()
|
||||
.to_ascii_uppercase();
|
||||
|
||||
trace!("Admin query: {}", query);
|
||||
|
||||
if query.starts_with("SHOW STATS") {
|
||||
trace!("SHOW STATS");
|
||||
show_stats(stream).await
|
||||
} else if query.starts_with("RELOAD") {
|
||||
trace!("RELOAD");
|
||||
reload(stream).await
|
||||
} else if query.starts_with("SHOW CONFIG") {
|
||||
trace!("SHOW CONFIG");
|
||||
show_config(stream).await
|
||||
} else if query.starts_with("SHOW DATABASES") {
|
||||
trace!("SHOW DATABASES");
|
||||
show_databases(stream, &pool).await
|
||||
} else if query.starts_with("SHOW POOLS") {
|
||||
trace!("SHOW POOLS");
|
||||
show_pools(stream, &pool).await
|
||||
} else if query.starts_with("SHOW LISTS") {
|
||||
trace!("SHOW LISTS");
|
||||
show_lists(stream, &pool).await
|
||||
} else if query.starts_with("SHOW VERSION") {
|
||||
trace!("SHOW VERSION");
|
||||
show_version(stream).await
|
||||
} else if query.starts_with("SET ") {
|
||||
trace!("SET");
|
||||
ignore_set(stream).await
|
||||
} else {
|
||||
error_response(stream, "Unsupported query against the admin database").await
|
||||
}
|
||||
}
|
||||
|
||||
/// SHOW LISTS
|
||||
async fn show_lists(stream: &mut OwnedWriteHalf, pool: &ConnectionPool) -> Result<(), Error> {
|
||||
let stats = get_stats();
|
||||
|
||||
let columns = vec![("list", DataType::Text), ("items", DataType::Int4)];
|
||||
|
||||
let mut res = BytesMut::new();
|
||||
res.put(row_description(&columns));
|
||||
res.put(data_row(&vec![
|
||||
"databases".to_string(),
|
||||
(pool.databases() + 1).to_string(), // see comment below
|
||||
]));
|
||||
res.put(data_row(&vec!["users".to_string(), "1".to_string()]));
|
||||
res.put(data_row(&vec![
|
||||
"pools".to_string(),
|
||||
(pool.databases() + 1).to_string(), // +1 for the pgbouncer admin db pool which isn't real
|
||||
])); // but admin tools that work with pgbouncer want this
|
||||
res.put(data_row(&vec![
|
||||
"free_clients".to_string(),
|
||||
stats["cl_idle"].to_string(),
|
||||
]));
|
||||
res.put(data_row(&vec![
|
||||
"used_clients".to_string(),
|
||||
stats["cl_active"].to_string(),
|
||||
]));
|
||||
res.put(data_row(&vec![
|
||||
"login_clients".to_string(),
|
||||
"0".to_string(),
|
||||
]));
|
||||
res.put(data_row(&vec![
|
||||
"free_servers".to_string(),
|
||||
stats["sv_idle"].to_string(),
|
||||
]));
|
||||
res.put(data_row(&vec![
|
||||
"used_servers".to_string(),
|
||||
stats["sv_active"].to_string(),
|
||||
]));
|
||||
res.put(data_row(&vec!["dns_names".to_string(), "0".to_string()]));
|
||||
res.put(data_row(&vec!["dns_zones".to_string(), "0".to_string()]));
|
||||
res.put(data_row(&vec!["dns_queries".to_string(), "0".to_string()]));
|
||||
res.put(data_row(&vec!["dns_pending".to_string(), "0".to_string()]));
|
||||
|
||||
res.put(command_complete("SHOW"));
|
||||
|
||||
res.put_u8(b'Z');
|
||||
res.put_i32(5);
|
||||
res.put_u8(b'I');
|
||||
|
||||
write_all_half(stream, res).await
|
||||
}
|
||||
|
||||
/// SHOW VERSION
|
||||
async fn show_version(stream: &mut OwnedWriteHalf) -> Result<(), Error> {
|
||||
let mut res = BytesMut::new();
|
||||
|
||||
res.put(row_description(&vec![("version", DataType::Text)]));
|
||||
res.put(data_row(&vec!["PgCat 0.1.0".to_string()]));
|
||||
res.put(command_complete("SHOW"));
|
||||
|
||||
res.put_u8(b'Z');
|
||||
res.put_i32(5);
|
||||
res.put_u8(b'I');
|
||||
|
||||
write_all_half(stream, res).await
|
||||
}
|
||||
|
||||
/// SHOW POOLS
|
||||
async fn show_pools(stream: &mut OwnedWriteHalf, _pool: &ConnectionPool) -> Result<(), Error> {
|
||||
let stats = get_stats();
|
||||
let config = {
|
||||
let guard = get_config();
|
||||
&*guard.clone()
|
||||
};
|
||||
|
||||
let columns = vec![
|
||||
("database", DataType::Text),
|
||||
("user", DataType::Text),
|
||||
("cl_active", DataType::Numeric),
|
||||
("cl_waiting", DataType::Numeric),
|
||||
("cl_cancel_req", DataType::Numeric),
|
||||
("sv_active", DataType::Numeric),
|
||||
("sv_idle", DataType::Numeric),
|
||||
("sv_used", DataType::Numeric),
|
||||
("sv_tested", DataType::Numeric),
|
||||
("sv_login", DataType::Numeric),
|
||||
("maxwait", DataType::Numeric),
|
||||
("maxwait_us", DataType::Numeric),
|
||||
("pool_mode", DataType::Text),
|
||||
];
|
||||
|
||||
let mut res = BytesMut::new();
|
||||
res.put(row_description(&columns));
|
||||
|
||||
let mut row = vec![String::from("all"), config.user.name.clone()];
|
||||
|
||||
for column in &columns[2..columns.len() - 1] {
|
||||
let value = stats.get(column.0).unwrap_or(&0).to_string();
|
||||
row.push(value);
|
||||
}
|
||||
|
||||
row.push(config.general.pool_mode.to_string());
|
||||
|
||||
res.put(data_row(&row));
|
||||
res.put(command_complete("SHOW"));
|
||||
|
||||
res.put_u8(b'Z');
|
||||
res.put_i32(5);
|
||||
res.put_u8(b'I');
|
||||
|
||||
write_all_half(stream, res).await
|
||||
}
|
||||
|
||||
/// SHOW DATABASES
|
||||
async fn show_databases(stream: &mut OwnedWriteHalf, pool: &ConnectionPool) -> Result<(), Error> {
|
||||
let guard = get_config();
|
||||
let config = &*guard.clone();
|
||||
drop(guard);
|
||||
|
||||
// Columns
|
||||
let columns = vec![
|
||||
("name", DataType::Text),
|
||||
("host", DataType::Text),
|
||||
("port", DataType::Text),
|
||||
("database", DataType::Text),
|
||||
("force_user", DataType::Text),
|
||||
("pool_size", DataType::Int4),
|
||||
("min_pool_size", DataType::Int4),
|
||||
("reserve_pool", DataType::Int4),
|
||||
("pool_mode", DataType::Text),
|
||||
("max_connections", DataType::Int4),
|
||||
("current_connections", DataType::Int4),
|
||||
("paused", DataType::Int4),
|
||||
("disabled", DataType::Int4),
|
||||
];
|
||||
|
||||
let mut res = BytesMut::new();
|
||||
|
||||
// RowDescription
|
||||
res.put(row_description(&columns));
|
||||
|
||||
for shard in 0..pool.shards() {
|
||||
let database_name = &config.shards[&shard.to_string()].database;
|
||||
|
||||
for server in 0..pool.servers(shard) {
|
||||
let address = pool.address(shard, server);
|
||||
let pool_state = pool.pool_state(shard, server);
|
||||
|
||||
res.put(data_row(&vec![
|
||||
address.name(), // name
|
||||
address.host.to_string(), // host
|
||||
address.port.to_string(), // port
|
||||
database_name.to_string(), // database
|
||||
config.user.name.to_string(), // force_user
|
||||
config.general.pool_size.to_string(), // pool_size
|
||||
"0".to_string(), // min_pool_size
|
||||
"0".to_string(), // reserve_pool
|
||||
config.general.pool_mode.to_string(), // pool_mode
|
||||
config.general.pool_size.to_string(), // max_connections
|
||||
pool_state.connections.to_string(), // current_connections
|
||||
"0".to_string(), // paused
|
||||
"0".to_string(), // disabled
|
||||
]));
|
||||
}
|
||||
}
|
||||
|
||||
res.put(command_complete("SHOW"));
|
||||
|
||||
// ReadyForQuery
|
||||
res.put_u8(b'Z');
|
||||
res.put_i32(5);
|
||||
res.put_u8(b'I');
|
||||
|
||||
write_all_half(stream, res).await
|
||||
}
|
||||
|
||||
/// Ignore any SET commands the client sends.
|
||||
/// This is common initialization done by ORMs.
|
||||
async fn ignore_set(stream: &mut OwnedWriteHalf) -> Result<(), Error> {
|
||||
custom_protocol_response_ok(stream, "SET").await
|
||||
}
|
||||
|
||||
/// RELOAD
|
||||
async fn reload(stream: &mut OwnedWriteHalf) -> Result<(), Error> {
|
||||
info!("Reloading config");
|
||||
|
||||
let config = get_config();
|
||||
let path = config.path.clone().unwrap();
|
||||
|
||||
parse(&path).await?;
|
||||
|
||||
let config = get_config();
|
||||
|
||||
config.show();
|
||||
|
||||
let mut res = BytesMut::new();
|
||||
|
||||
// CommandComplete
|
||||
res.put(command_complete("RELOAD"));
|
||||
|
||||
// ReadyForQuery
|
||||
res.put_u8(b'Z');
|
||||
res.put_i32(5);
|
||||
res.put_u8(b'I');
|
||||
|
||||
write_all_half(stream, res).await
|
||||
}
|
||||
|
||||
async fn show_config(stream: &mut OwnedWriteHalf) -> Result<(), Error> {
|
||||
let guard = get_config();
|
||||
let config = &*guard.clone();
|
||||
let config: HashMap<String, String> = config.into();
|
||||
drop(guard);
|
||||
|
||||
// Configs that cannot be changed dynamically.
|
||||
let immutables = ["host", "port", "connect_timeout"];
|
||||
|
||||
// Columns
|
||||
let columns = vec![
|
||||
("key", DataType::Text),
|
||||
("value", DataType::Text),
|
||||
("default", DataType::Text),
|
||||
("changeable", DataType::Text),
|
||||
];
|
||||
|
||||
// Response data
|
||||
let mut res = BytesMut::new();
|
||||
res.put(row_description(&columns));
|
||||
|
||||
// DataRow rows
|
||||
for (key, value) in config {
|
||||
let changeable = if immutables.iter().filter(|col| *col == &key).count() == 1 {
|
||||
"no".to_string()
|
||||
} else {
|
||||
"yes".to_string()
|
||||
};
|
||||
|
||||
let row = vec![key, value, "-".to_string(), changeable];
|
||||
|
||||
res.put(data_row(&row));
|
||||
}
|
||||
|
||||
res.put(command_complete("SHOW"));
|
||||
|
||||
res.put_u8(b'Z');
|
||||
res.put_i32(5);
|
||||
res.put_u8(b'I');
|
||||
|
||||
write_all_half(stream, res).await
|
||||
}
|
||||
|
||||
/// SHOW STATS
|
||||
async fn show_stats(stream: &mut OwnedWriteHalf) -> Result<(), Error> {
|
||||
let columns = vec![
|
||||
("database", DataType::Text),
|
||||
("total_xact_count", DataType::Numeric),
|
||||
("total_query_count", DataType::Numeric),
|
||||
("total_received", DataType::Numeric),
|
||||
("total_sent", DataType::Numeric),
|
||||
("total_xact_time", DataType::Numeric),
|
||||
("total_query_time", DataType::Numeric),
|
||||
("total_wait_time", DataType::Numeric),
|
||||
("avg_xact_count", DataType::Numeric),
|
||||
("avg_query_count", DataType::Numeric),
|
||||
("avg_recv", DataType::Numeric),
|
||||
("avg_sent", DataType::Numeric),
|
||||
("avg_xact_time", DataType::Numeric),
|
||||
("avg_query_time", DataType::Numeric),
|
||||
("avg_wait_time", DataType::Numeric),
|
||||
];
|
||||
|
||||
let stats = get_stats();
|
||||
let mut res = BytesMut::new();
|
||||
res.put(row_description(&columns));
|
||||
|
||||
let mut row = vec![
|
||||
String::from("all"), // TODO: per-database stats,
|
||||
];
|
||||
|
||||
for column in &columns[1..] {
|
||||
row.push(stats.get(column.0).unwrap_or(&0).to_string());
|
||||
}
|
||||
|
||||
res.put(data_row(&row));
|
||||
res.put(command_complete("SHOW"));
|
||||
|
||||
res.put_u8(b'Z');
|
||||
res.put_i32(5);
|
||||
res.put_u8(b'I');
|
||||
|
||||
write_all_half(stream, res).await
|
||||
}
|
||||
@@ -11,6 +11,7 @@ use tokio::net::{
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::admin::handle_admin;
|
||||
use crate::config::get_config;
|
||||
use crate::constants::*;
|
||||
use crate::errors::Error;
|
||||
@@ -54,6 +55,9 @@ pub struct Client {
|
||||
|
||||
// Statistics
|
||||
stats: Reporter,
|
||||
|
||||
// Clients want to talk to admin
|
||||
admin: bool,
|
||||
}
|
||||
|
||||
impl Client {
|
||||
@@ -104,20 +108,32 @@ impl Client {
|
||||
// Regular startup message.
|
||||
PROTOCOL_VERSION_NUMBER => {
|
||||
trace!("Got StartupMessage");
|
||||
|
||||
// TODO: perform actual auth.
|
||||
let parameters = parse_startup(bytes.clone())?;
|
||||
let mut user_name: String = String::new();
|
||||
match parameters.get(&"user") {
|
||||
Some(&user) => user_name = user,
|
||||
None => return Err(Error::ClientBadStartup),
|
||||
}
|
||||
start_auth(&mut stream, &user_name).await?;
|
||||
|
||||
// Generate random backend ID and secret key
|
||||
let process_id: i32 = rand::random();
|
||||
let secret_key: i32 = rand::random();
|
||||
|
||||
auth_ok(&mut stream).await?;
|
||||
write_all(&mut stream, server_info).await?;
|
||||
backend_key_data(&mut stream, process_id, secret_key).await?;
|
||||
ready_for_query(&mut stream).await?;
|
||||
trace!("Startup OK");
|
||||
|
||||
let database = parameters
|
||||
.get("database")
|
||||
.unwrap_or(parameters.get("user").unwrap());
|
||||
let admin = ["pgcat", "pgbouncer"]
|
||||
.iter()
|
||||
.filter(|db| *db == &database)
|
||||
.count()
|
||||
== 1;
|
||||
|
||||
// Split the read and write streams
|
||||
// so we can control buffering.
|
||||
let (read, write) = stream.into_split();
|
||||
@@ -133,6 +149,7 @@ impl Client {
|
||||
client_server_map: client_server_map,
|
||||
parameters: parameters,
|
||||
stats: stats,
|
||||
admin: admin,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -154,6 +171,7 @@ impl Client {
|
||||
client_server_map: client_server_map,
|
||||
parameters: HashMap::new(),
|
||||
stats: stats,
|
||||
admin: false,
|
||||
});
|
||||
}
|
||||
|
||||
@@ -220,6 +238,13 @@ impl Client {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Handle admin database real quick
|
||||
if self.admin {
|
||||
trace!("Handling admin command");
|
||||
handle_admin(&mut self.write, message, pool.clone()).await?;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle all custom protocol commands here.
|
||||
match query_router.try_execute_command(message.clone()) {
|
||||
// Normal query
|
||||
|
||||
@@ -19,6 +19,15 @@ pub enum Role {
|
||||
Replica,
|
||||
}
|
||||
|
||||
impl ToString for Role {
|
||||
fn to_string(&self) -> String {
|
||||
match *self {
|
||||
Role::Primary => "primary".to_string(),
|
||||
Role::Replica => "replica".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq<Option<Role>> for Role {
|
||||
fn eq(&self, other: &Option<Role>) -> bool {
|
||||
match other {
|
||||
@@ -43,6 +52,7 @@ pub struct Address {
|
||||
pub port: String,
|
||||
pub shard: usize,
|
||||
pub role: Role,
|
||||
pub replica_number: usize,
|
||||
}
|
||||
|
||||
impl Default for Address {
|
||||
@@ -51,11 +61,22 @@ impl Default for Address {
|
||||
host: String::from("127.0.0.1"),
|
||||
port: String::from("5432"),
|
||||
shard: 0,
|
||||
replica_number: 0,
|
||||
role: Role::Replica,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Address {
|
||||
pub fn name(&self) -> String {
|
||||
match self.role {
|
||||
Role::Primary => format!("shard_{}_primary", self.shard),
|
||||
|
||||
Role::Replica => format!("shard_{}_replica_{}", self.shard, self.replica_number),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, PartialEq, Hash, std::cmp::Eq, Deserialize, Debug)]
|
||||
pub struct User {
|
||||
pub name: String,
|
||||
@@ -134,6 +155,7 @@ impl Default for QueryRouter {
|
||||
|
||||
#[derive(Deserialize, Debug, Clone)]
|
||||
pub struct Config {
|
||||
pub path: Option<String>,
|
||||
pub general: General,
|
||||
pub user: User,
|
||||
pub shards: HashMap<String, Shard>,
|
||||
@@ -143,6 +165,7 @@ pub struct Config {
|
||||
impl Default for Config {
|
||||
fn default() -> Config {
|
||||
Config {
|
||||
path: Some(String::from("pgcat.toml")),
|
||||
general: General::default(),
|
||||
user: User::default(),
|
||||
shards: HashMap::from([(String::from("1"), Shard::default())]),
|
||||
@@ -151,6 +174,52 @@ impl Default for Config {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&Config> for std::collections::HashMap<String, String> {
|
||||
fn from(config: &Config) -> HashMap<String, String> {
|
||||
HashMap::from([
|
||||
("host".to_string(), config.general.host.to_string()),
|
||||
("port".to_string(), config.general.port.to_string()),
|
||||
(
|
||||
"pool_size".to_string(),
|
||||
config.general.pool_size.to_string(),
|
||||
),
|
||||
(
|
||||
"pool_mode".to_string(),
|
||||
config.general.pool_mode.to_string(),
|
||||
),
|
||||
(
|
||||
"connect_timeout".to_string(),
|
||||
config.general.connect_timeout.to_string(),
|
||||
),
|
||||
(
|
||||
"healthcheck_timeout".to_string(),
|
||||
config.general.healthcheck_timeout.to_string(),
|
||||
),
|
||||
("ban_time".to_string(), config.general.ban_time.to_string()),
|
||||
(
|
||||
"statsd_address".to_string(),
|
||||
config.general.statsd_address.to_string(),
|
||||
),
|
||||
(
|
||||
"default_role".to_string(),
|
||||
config.query_router.default_role.to_string(),
|
||||
),
|
||||
(
|
||||
"query_parser_enabled".to_string(),
|
||||
config.query_router.query_parser_enabled.to_string(),
|
||||
),
|
||||
(
|
||||
"primary_reads_enabled".to_string(),
|
||||
config.query_router.primary_reads_enabled.to_string(),
|
||||
),
|
||||
(
|
||||
"sharding_function".to_string(),
|
||||
config.query_router.sharding_function.to_string(),
|
||||
),
|
||||
])
|
||||
}
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn show(&self) {
|
||||
info!("Pool size: {}", self.general.pool_size);
|
||||
@@ -189,7 +258,7 @@ pub async fn parse(path: &str) -> Result<(), Error> {
|
||||
}
|
||||
};
|
||||
|
||||
let config: Config = match toml::from_str(&contents) {
|
||||
let mut config: Config = match toml::from_str(&contents) {
|
||||
Ok(config) => config,
|
||||
Err(err) => {
|
||||
error!("Could not parse config file: {}", err.to_string());
|
||||
@@ -279,6 +348,8 @@ pub async fn parse(path: &str) -> Result<(), Error> {
|
||||
}
|
||||
};
|
||||
|
||||
config.path = Some(path.to_string());
|
||||
|
||||
CONFIG.store(Arc::new(config.clone()));
|
||||
|
||||
Ok(())
|
||||
@@ -296,5 +367,6 @@ mod test {
|
||||
assert_eq!(get_config().shards["1"].servers[0].0, "127.0.0.1");
|
||||
assert_eq!(get_config().shards["0"].servers[0].2, "primary");
|
||||
assert_eq!(get_config().query_router.default_role, "any");
|
||||
assert_eq!(get_config().path, Some("pgcat.toml".to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,3 +20,8 @@ pub const AUTHENTICATION_SUCCESSFUL: i32 = 0;
|
||||
|
||||
// ErrorResponse: A code identifying the field type; if zero, this is the message terminator and no string follows.
|
||||
pub const MESSAGE_TERMINATOR: u8 = 0;
|
||||
|
||||
//
|
||||
// Data types
|
||||
//
|
||||
pub const _OID_INT8: i32 = 20; // bigint
|
||||
|
||||
@@ -8,5 +8,7 @@ pub enum Error {
|
||||
// ServerTimeout,
|
||||
// DirtyServer,
|
||||
BadConfig,
|
||||
BadUserList,
|
||||
AllServersDown,
|
||||
AuthenticationError
|
||||
}
|
||||
|
||||
11
src/main.rs
11
src/main.rs
@@ -47,8 +47,10 @@ use tokio::{
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
mod admin;
|
||||
mod client;
|
||||
mod config;
|
||||
mod userlist;
|
||||
mod constants;
|
||||
mod errors;
|
||||
mod messages;
|
||||
@@ -93,6 +95,15 @@ async fn main() {
|
||||
}
|
||||
};
|
||||
|
||||
// Prepare user list
|
||||
match userlist::parse("userlist.json").await {
|
||||
Ok(_) => (),
|
||||
Err(err) => {
|
||||
error!("Userlist parse error: {:?}", err);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let config = get_config();
|
||||
|
||||
let addr = format!("{}:{}", config.general.host, config.general.port);
|
||||
|
||||
264
src/messages.rs
264
src/messages.rs
@@ -1,5 +1,7 @@
|
||||
/// Helper functions to send one-off protocol messages
|
||||
/// and handle TcpStream (TCP socket).
|
||||
|
||||
|
||||
use bytes::{Buf, BufMut, BytesMut};
|
||||
use md5::{Digest, Md5};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
|
||||
@@ -7,10 +9,125 @@ use tokio::net::{
|
||||
tcp::{OwnedReadHalf, OwnedWriteHalf},
|
||||
TcpStream,
|
||||
};
|
||||
|
||||
use std::collections::HashMap;
|
||||
use log::{error};
|
||||
|
||||
use crate::errors::Error;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use rand::Rng;
|
||||
|
||||
use crate::userlist::get_user_list;
|
||||
|
||||
|
||||
/// Postgres data type mappings
|
||||
/// used in RowDescription ('T') message.
|
||||
pub enum DataType {
|
||||
Text,
|
||||
Int4,
|
||||
Numeric,
|
||||
}
|
||||
|
||||
impl From<&DataType> for i32 {
|
||||
fn from(data_type: &DataType) -> i32 {
|
||||
match data_type {
|
||||
DataType::Text => 25,
|
||||
DataType::Int4 => 23,
|
||||
DataType::Numeric => 1700,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
1. Generate salt (4 bytes of random data)
|
||||
md5(concat(md5(concat(password, username)), random-salt)))
|
||||
2. Send md5 auth request
|
||||
3. recieve PasswordMessage with salt.
|
||||
4. refactor md5_password function to be reusable
|
||||
5. check username hash combo against file
|
||||
6. AuthenticationOk or ErrorResponse
|
||||
**/
|
||||
pub async fn start_auth(stream: &mut TcpStream, user_name: &String) -> Result<(), Error> {
|
||||
let mut rng = rand::thread_rng();
|
||||
|
||||
//Generate random 4 byte salt
|
||||
let salt = rng.gen::<u32>();
|
||||
|
||||
// Send AuthenticationMD5Password request
|
||||
send_md5_request(stream, salt).await?;
|
||||
|
||||
let code = match stream.read_u8().await {
|
||||
Ok(code) => code as char,
|
||||
Err(_) => return Err(Error::AuthenticationError),
|
||||
};
|
||||
|
||||
match code {
|
||||
// Password response
|
||||
'p' => {
|
||||
fetch_password_and_authenticate(stream, &user_name, &salt).await?;
|
||||
Ok(auth_ok(stream).await?)
|
||||
}
|
||||
_ => {
|
||||
error!("Unknown code: {}", code);
|
||||
return Err(Error::AuthenticationError);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn send_md5_request(stream: &mut TcpStream, salt: u32) -> Result<(), Error> {
|
||||
let mut authentication_md5password = BytesMut::with_capacity(12);
|
||||
authentication_md5password.put_u8(b'R');
|
||||
authentication_md5password.put_i32(12);
|
||||
authentication_md5password.put_i32(5);
|
||||
authentication_md5password.put_u32(salt);
|
||||
|
||||
// Send AuthenticationMD5Password request
|
||||
Ok(write_all(stream, authentication_md5password).await?)
|
||||
}
|
||||
|
||||
pub async fn fetch_password_and_authenticate(stream: &mut TcpStream, user_name: &String, salt: &u32) -> Result<(), Error> {
|
||||
/**
|
||||
1. How do I store the lists of users and paswords? clear text or hash?? wtf
|
||||
2. Add auth to tests
|
||||
**/
|
||||
|
||||
let len = match stream.read_i32().await {
|
||||
Ok(len) => len,
|
||||
Err(_) => return Err(Error::AuthenticationError),
|
||||
};
|
||||
|
||||
// Read whatever is left.
|
||||
let mut password_hash = vec![0u8; len as usize - 4];
|
||||
|
||||
match stream.read_exact(&mut password_hash).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => return Err(Error::AuthenticationError),
|
||||
};
|
||||
|
||||
let user_list = get_user_list();
|
||||
let mut password: String = String::new();
|
||||
match user_list.get(&user_name) {
|
||||
Some(&p) => password = p,
|
||||
None => return Err(Error::AuthenticationError),
|
||||
}
|
||||
|
||||
let mut md5 = Md5::new();
|
||||
|
||||
// concat('md5', md5(concat(md5(concat(password, username)), random-salt)))
|
||||
// First pass
|
||||
md5.update(&password.as_bytes());
|
||||
md5.update(&user_name.as_bytes());
|
||||
let output = md5.finalize_reset();
|
||||
// Second pass
|
||||
md5.update(format!("{:x}", output));
|
||||
md5.update(salt.to_be_bytes().to_vec());
|
||||
|
||||
|
||||
let password_string: String = String::from_utf8(password_hash).expect("Could not get password hash");
|
||||
match format!("md5{:x}", md5.finalize()) == password_string {
|
||||
true => Ok(()),
|
||||
_ => Err(Error::AuthenticationError)
|
||||
}
|
||||
}
|
||||
|
||||
/// Tell the client that authentication handshake completed successfully.
|
||||
pub async fn auth_ok(stream: &mut TcpStream) -> Result<(), Error> {
|
||||
@@ -91,9 +208,8 @@ pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Resu
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse StartupMessage parameters.
|
||||
/// e.g. user, database, application_name, etc.
|
||||
pub fn parse_startup(mut bytes: BytesMut) -> Result<HashMap<String, String>, Error> {
|
||||
/// Parse the params the server sends as a key/value format.
|
||||
pub fn parse_params(mut bytes: BytesMut) -> Result<HashMap<String, String>, Error> {
|
||||
let mut result = HashMap::new();
|
||||
let mut buf = Vec::new();
|
||||
let mut tmp = String::new();
|
||||
@@ -115,7 +231,7 @@ pub fn parse_startup(mut bytes: BytesMut) -> Result<HashMap<String, String>, Err
|
||||
|
||||
// Expect pairs of name and value
|
||||
// and at least one pair to be present.
|
||||
if buf.len() % 2 != 0 && buf.len() >= 2 {
|
||||
if buf.len() % 2 != 0 || buf.len() < 2 {
|
||||
return Err(Error::ClientBadStartup);
|
||||
}
|
||||
|
||||
@@ -127,6 +243,14 @@ pub fn parse_startup(mut bytes: BytesMut) -> Result<HashMap<String, String>, Err
|
||||
i += 2;
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Parse StartupMessage parameters.
|
||||
/// e.g. user, database, application_name, etc.
|
||||
pub fn parse_startup(bytes: BytesMut) -> Result<HashMap<String, String>, Error> {
|
||||
let result = parse_params(bytes)?;
|
||||
|
||||
// Minimum required parameters
|
||||
// I want to have the user at the very minimum, according to the protocol spec.
|
||||
if !result.contains_key("user") {
|
||||
@@ -252,68 +376,17 @@ pub async fn show_response(
|
||||
// 3. CommandComplete
|
||||
// 4. ReadyForQuery
|
||||
|
||||
// RowDescription
|
||||
let mut row_desc = BytesMut::new();
|
||||
|
||||
// Number of columns: 1
|
||||
row_desc.put_i16(1);
|
||||
|
||||
// Column name
|
||||
row_desc.put_slice(&format!("{}\0", name).as_bytes());
|
||||
|
||||
// Doesn't belong to any table
|
||||
row_desc.put_i32(0);
|
||||
|
||||
// Doesn't belong to any table
|
||||
row_desc.put_i16(0);
|
||||
|
||||
// Text
|
||||
row_desc.put_i32(25);
|
||||
|
||||
// Text size = variable (-1)
|
||||
row_desc.put_i16(-1);
|
||||
|
||||
// Type modifier: none that I know
|
||||
row_desc.put_i32(0);
|
||||
|
||||
// Format being used: text (0), binary (1)
|
||||
row_desc.put_i16(0);
|
||||
|
||||
// DataRow
|
||||
let mut data_row = BytesMut::new();
|
||||
|
||||
// Number of columns
|
||||
data_row.put_i16(1);
|
||||
|
||||
// Size of the column content (length of the string really)
|
||||
data_row.put_i32(value.len() as i32);
|
||||
|
||||
// The content
|
||||
data_row.put_slice(value.as_bytes());
|
||||
|
||||
// CommandComplete
|
||||
let mut command_complete = BytesMut::new();
|
||||
|
||||
// Number of rows returned (just one)
|
||||
command_complete.put_slice(&b"SELECT 1\0"[..]);
|
||||
|
||||
// The final messages sent to the client
|
||||
let mut res = BytesMut::new();
|
||||
|
||||
// RowDescription
|
||||
res.put_u8(b'T');
|
||||
res.put_i32(row_desc.len() as i32 + 4);
|
||||
res.put(row_desc);
|
||||
res.put(row_description(&vec![(name, DataType::Text)]));
|
||||
|
||||
// DataRow
|
||||
res.put_u8(b'D');
|
||||
res.put_i32(data_row.len() as i32 + 4);
|
||||
res.put(data_row);
|
||||
res.put(data_row(&vec![value.to_string()]));
|
||||
|
||||
// CommandComplete
|
||||
res.put_u8(b'C');
|
||||
res.put_i32(command_complete.len() as i32 + 4);
|
||||
res.put(command_complete);
|
||||
res.put(command_complete("SELECT 1"));
|
||||
|
||||
// ReadyForQuery
|
||||
res.put_u8(b'Z');
|
||||
@@ -323,6 +396,77 @@ pub async fn show_response(
|
||||
write_all_half(stream, res).await
|
||||
}
|
||||
|
||||
pub fn row_description(columns: &Vec<(&str, DataType)>) -> BytesMut {
|
||||
let mut res = BytesMut::new();
|
||||
let mut row_desc = BytesMut::new();
|
||||
|
||||
// how many colums we are storing
|
||||
row_desc.put_i16(columns.len() as i16);
|
||||
|
||||
for (name, data_type) in columns {
|
||||
// Column name
|
||||
row_desc.put_slice(&format!("{}\0", name).as_bytes());
|
||||
|
||||
// Doesn't belong to any table
|
||||
row_desc.put_i32(0);
|
||||
|
||||
// Doesn't belong to any table
|
||||
row_desc.put_i16(0);
|
||||
|
||||
// Text
|
||||
row_desc.put_i32(data_type.into());
|
||||
|
||||
// Text size = variable (-1)
|
||||
let type_size = match data_type {
|
||||
DataType::Text => -1,
|
||||
DataType::Int4 => 4,
|
||||
DataType::Numeric => -1,
|
||||
};
|
||||
|
||||
row_desc.put_i16(type_size);
|
||||
|
||||
// Type modifier: none that I know
|
||||
row_desc.put_i32(-1);
|
||||
|
||||
// Format being used: text (0), binary (1)
|
||||
row_desc.put_i16(0);
|
||||
}
|
||||
|
||||
res.put_u8(b'T');
|
||||
res.put_i32(row_desc.len() as i32 + 4);
|
||||
res.put(row_desc);
|
||||
|
||||
res
|
||||
}
|
||||
|
||||
pub fn data_row(row: &Vec<String>) -> BytesMut {
|
||||
let mut res = BytesMut::new();
|
||||
let mut data_row = BytesMut::new();
|
||||
|
||||
data_row.put_i16(row.len() as i16);
|
||||
|
||||
for column in row {
|
||||
let column = column.as_bytes();
|
||||
data_row.put_i32(column.len() as i32);
|
||||
data_row.put_slice(&column);
|
||||
}
|
||||
|
||||
res.put_u8(b'D');
|
||||
res.put_i32(data_row.len() as i32 + 4);
|
||||
res.put(data_row);
|
||||
|
||||
res
|
||||
}
|
||||
|
||||
pub fn command_complete(command: &str) -> BytesMut {
|
||||
let cmd = BytesMut::from(format!("{}\0", command).as_bytes());
|
||||
let mut res = BytesMut::new();
|
||||
res.put_u8(b'C');
|
||||
res.put_i32(cmd.len() as i32 + 4);
|
||||
res.put(cmd);
|
||||
res
|
||||
}
|
||||
|
||||
/// Write all data in the buffer to the TcpStream.
|
||||
pub async fn write_all(stream: &mut TcpStream, buf: BytesMut) -> Result<(), Error> {
|
||||
match stream.write_all(&buf).await {
|
||||
|
||||
44
src/pool.rs
44
src/pool.rs
@@ -49,7 +49,8 @@ impl ConnectionPool {
|
||||
for shard_idx in shard_ids {
|
||||
let shard = &config.shards[&shard_idx];
|
||||
let mut pools = Vec::new();
|
||||
let mut replica_addresses = Vec::new();
|
||||
let mut servers = Vec::new();
|
||||
let mut replica_number = 0;
|
||||
|
||||
for server in shard.servers.iter() {
|
||||
let role = match server.2.as_ref() {
|
||||
@@ -65,9 +66,14 @@ impl ConnectionPool {
|
||||
host: server.0.clone(),
|
||||
port: server.1.to_string(),
|
||||
role: role,
|
||||
replica_number,
|
||||
shard: shard_idx.parse::<usize>().unwrap(),
|
||||
};
|
||||
|
||||
if role == Role::Replica {
|
||||
replica_number += 1;
|
||||
}
|
||||
|
||||
let manager = ServerPool::new(
|
||||
address.clone(),
|
||||
config.user.clone(),
|
||||
@@ -87,11 +93,11 @@ impl ConnectionPool {
|
||||
.unwrap();
|
||||
|
||||
pools.push(pool);
|
||||
replica_addresses.push(address);
|
||||
servers.push(address);
|
||||
}
|
||||
|
||||
shards.push(pools);
|
||||
addresses.push(replica_addresses);
|
||||
addresses.push(servers);
|
||||
banlist.push(HashMap::new());
|
||||
}
|
||||
|
||||
@@ -126,10 +132,22 @@ impl ConnectionPool {
|
||||
};
|
||||
|
||||
let mut proxy = connection.0;
|
||||
let _address = connection.1;
|
||||
let address = connection.1;
|
||||
let server = &mut *proxy;
|
||||
|
||||
server_infos.push(server.server_info());
|
||||
let server_info = server.server_info();
|
||||
|
||||
if server_infos.len() > 0 {
|
||||
// Compare against the last server checked.
|
||||
if server_info != server_infos[server_infos.len() - 1] {
|
||||
warn!(
|
||||
"{:?} has different server configuration than the last server",
|
||||
address
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
server_infos.push(server_info);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -324,6 +342,22 @@ impl ConnectionPool {
|
||||
pub fn servers(&self, shard: usize) -> usize {
|
||||
self.addresses[shard].len()
|
||||
}
|
||||
|
||||
pub fn databases(&self) -> usize {
|
||||
let mut databases = 0;
|
||||
for shard in 0..self.shards() {
|
||||
databases += self.servers(shard);
|
||||
}
|
||||
databases
|
||||
}
|
||||
|
||||
pub fn pool_state(&self, shard: usize, server: usize) -> bb8::State {
|
||||
self.databases[shard][server].state()
|
||||
}
|
||||
|
||||
pub fn address(&self, shard: usize, server: usize) -> &Address {
|
||||
&self.addresses[shard][server]
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ServerPool {
|
||||
|
||||
95
src/stats.rs
95
src/stats.rs
@@ -1,12 +1,17 @@
|
||||
use log::info;
|
||||
use log::{debug, info};
|
||||
use once_cell::sync::Lazy;
|
||||
use parking_lot::Mutex;
|
||||
use statsd::Client;
|
||||
/// Events collector and publisher.
|
||||
use tokio::sync::mpsc::{Receiver, Sender};
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::config::get_config;
|
||||
|
||||
// Stats used in SHOW STATS
|
||||
static LATEST_STATS: Lazy<Mutex<HashMap<String, i64>>> = Lazy::new(|| Mutex::new(HashMap::new()));
|
||||
static STAT_PERIOD: u64 = 15000; //15 seconds
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
enum EventName {
|
||||
CheckoutTime,
|
||||
@@ -182,16 +187,6 @@ impl Reporter {
|
||||
|
||||
let _ = self.tx.try_send(event);
|
||||
}
|
||||
|
||||
// pub fn flush_to_statsd(&self) {
|
||||
// let event = Event {
|
||||
// name: EventName::FlushStatsToStatsD,
|
||||
// value: 0,
|
||||
// process_id: None,
|
||||
// };
|
||||
|
||||
// let _ = self.tx.try_send(event);
|
||||
// }
|
||||
}
|
||||
|
||||
pub struct Collector {
|
||||
@@ -217,7 +212,15 @@ impl Collector {
|
||||
("total_xact_count", 0),
|
||||
("total_sent", 0),
|
||||
("total_received", 0),
|
||||
("total_xact_time", 0),
|
||||
("total_query_time", 0),
|
||||
("total_wait_time", 0),
|
||||
("avg_xact_time", 0),
|
||||
("avg_query_time", 0),
|
||||
("avg_xact_count", 0),
|
||||
("avg_sent", 0),
|
||||
("avg_received", 0),
|
||||
("avg_wait_time", 0),
|
||||
("maxwait_us", 0),
|
||||
("maxwait", 0),
|
||||
("cl_waiting", 0),
|
||||
@@ -229,11 +232,18 @@ impl Collector {
|
||||
("sv_tested", 0),
|
||||
]);
|
||||
|
||||
let mut client_server_states: HashMap<i32, EventName> = HashMap::new();
|
||||
let tx = self.tx.clone();
|
||||
// Stats saved after each iteration of the flush event. Used in calculation
|
||||
// of averages in the last flush period.
|
||||
let mut old_stats: HashMap<String, i64> = HashMap::new();
|
||||
|
||||
// Track which state the client and server are at any given time.
|
||||
let mut client_server_states: HashMap<i32, EventName> = HashMap::new();
|
||||
|
||||
// Flush stats to StatsD and calculate averages every 15 seconds.
|
||||
let tx = self.tx.clone();
|
||||
tokio::task::spawn(async move {
|
||||
let mut interval = tokio::time::interval(tokio::time::Duration::from_millis(15000));
|
||||
let mut interval =
|
||||
tokio::time::interval(tokio::time::Duration::from_millis(STAT_PERIOD));
|
||||
loop {
|
||||
interval.tick().await;
|
||||
let _ = tx.try_send(Event {
|
||||
@@ -244,6 +254,7 @@ impl Collector {
|
||||
}
|
||||
});
|
||||
|
||||
// The collector loop
|
||||
loop {
|
||||
let stat = match self.rx.recv().await {
|
||||
Some(stat) => stat,
|
||||
@@ -280,10 +291,11 @@ impl Collector {
|
||||
*counter += stat.value;
|
||||
|
||||
let counter = stats.entry("maxwait_us").or_insert(0);
|
||||
let mic_part = stat.value % 1_000_000;
|
||||
|
||||
// Report max time here
|
||||
if stat.value > *counter {
|
||||
*counter = stat.value;
|
||||
if mic_part > *counter {
|
||||
*counter = mic_part;
|
||||
}
|
||||
|
||||
let counter = stats.entry("maxwait").or_insert(0);
|
||||
@@ -309,6 +321,7 @@ impl Collector {
|
||||
}
|
||||
|
||||
EventName::FlushStatsToStatsD => {
|
||||
// Calculate connection states
|
||||
for (_, state) in &client_server_states {
|
||||
match state {
|
||||
EventName::ClientActive => {
|
||||
@@ -350,13 +363,51 @@ impl Collector {
|
||||
};
|
||||
}
|
||||
|
||||
info!("{:?}", stats);
|
||||
// Calculate averages
|
||||
for stat in &[
|
||||
"avg_query_count",
|
||||
"avgxact_count",
|
||||
"avg_sent",
|
||||
"avg_received",
|
||||
"avg_wait_time",
|
||||
] {
|
||||
let total_name = stat.replace("avg_", "total_");
|
||||
let old_value = old_stats.entry(total_name.clone()).or_insert(0);
|
||||
let new_value = stats.get(total_name.as_str()).unwrap_or(&0).to_owned();
|
||||
let avg = (new_value - *old_value) / (STAT_PERIOD as i64 / 1_000); // Avg / second
|
||||
|
||||
stats.insert(stat, avg);
|
||||
*old_value = new_value;
|
||||
}
|
||||
|
||||
debug!("{:?}", stats);
|
||||
|
||||
// Update latest stats used in SHOW STATS
|
||||
let mut guard = LATEST_STATS.lock();
|
||||
for (key, value) in &stats {
|
||||
guard.insert(key.to_string(), value.clone());
|
||||
}
|
||||
|
||||
let mut pipeline = self.client.pipeline();
|
||||
|
||||
for (key, value) in stats.iter_mut() {
|
||||
for (key, value) in stats.iter() {
|
||||
pipeline.gauge(key, *value as f64);
|
||||
*value = 0;
|
||||
}
|
||||
|
||||
// These are re-calculated every iteration of the loop, so we don't want to add values
|
||||
// from the last iteration.
|
||||
for stat in &[
|
||||
"cl_active",
|
||||
"cl_waiting",
|
||||
"cl_idle",
|
||||
"sv_idle",
|
||||
"sv_active",
|
||||
"sv_tested",
|
||||
"sv_login",
|
||||
"maxwait",
|
||||
"maxwait_us",
|
||||
] {
|
||||
stats.insert(stat, 0);
|
||||
}
|
||||
|
||||
pipeline.send(&self.client);
|
||||
@@ -365,3 +416,7 @@ impl Collector {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_stats() -> HashMap<String, i64> {
|
||||
LATEST_STATS.lock().clone()
|
||||
}
|
||||
|
||||
4
src/userlist.json
Normal file
4
src/userlist.json
Normal file
@@ -0,0 +1,4 @@
|
||||
{
|
||||
"sven": "clear_text_password",
|
||||
"sharding_user": "sharding_user"
|
||||
}
|
||||
57
src/userlist.rs
Normal file
57
src/userlist.rs
Normal file
@@ -0,0 +1,57 @@
|
||||
use arc_swap::{ArcSwap, Guard};
|
||||
use log::{error};
|
||||
use once_cell::sync::Lazy;
|
||||
use tokio::fs::File;
|
||||
use tokio::io::AsyncReadExt;
|
||||
|
||||
use std::collections::{HashMap};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::errors::Error;
|
||||
|
||||
pub type UserList = HashMap<String, String>;
|
||||
static USER_LIST: Lazy<ArcSwap<UserList>> = Lazy::new(|| ArcSwap::from_pointee(HashMap::new()));
|
||||
|
||||
pub fn get_user_list() -> Guard<Arc<UserList>> {
|
||||
USER_LIST.load()
|
||||
}
|
||||
|
||||
/// Parse the user list.
|
||||
pub async fn parse(path: &str) -> Result<(), Error> {
|
||||
let mut contents = String::new();
|
||||
let mut file = match File::open(path).await {
|
||||
Ok(file) => file,
|
||||
Err(err) => {
|
||||
error!("Could not open '{}': {}", path, err.to_string());
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
};
|
||||
|
||||
match file.read_to_string(&mut contents).await {
|
||||
Ok(_) => (),
|
||||
Err(err) => {
|
||||
error!("Could not read config file: {}", err.to_string());
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
};
|
||||
|
||||
let map: HashMap<String, String> = serde_json::from_str(&contents).expect("JSON was not well-formatted");
|
||||
|
||||
|
||||
|
||||
USER_LIST.store(Arc::new(map.clone()));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_config() {
|
||||
parse("userlist.json").await.unwrap();
|
||||
assert_eq!(get_user_list()["sven"], "clear_text_password");
|
||||
assert_eq!(get_user_list()["sharding_user"], "sharding_user");
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user