Compare commits

..

11 Commits

Author SHA1 Message Date
Sven Elfgren
5f745859c0 Started on md5 auth.
Left to figure out:
Whats the right format to store user:pw in userlist?
hashmap errors?
actually do hash:x comparison
2022-03-04 12:11:17 -07:00
Lev Kokotov
1e8fa110ae Fix pgbouncerhero (#54) 2022-03-02 14:46:31 -08:00
Lev Kokotov
d4186b7815 More admin (#53)
* more admin

* more admin

* show lists

* tests
2022-03-01 22:49:43 -08:00
Lev Kokotov
aaeef69d59 Refactor admin (#52) 2022-03-01 08:47:19 -08:00
Lev Kokotov
b21e0f4a7e admin SHOW DATABASES (#51)
* admin SHOW DATABASES

* test

* correct replica count
2022-02-28 17:22:28 -08:00
Lev Kokotov
eb1473060e admin: SHOW CONFIG (#50)
* admin: SHOW CONFIG

* test
2022-02-28 08:14:39 -08:00
Lev Kokotov
26f75f8d5d admin RELOAD (#49)
* admin RELOAD

* test
2022-02-27 10:21:24 -08:00
Lev Kokotov
99d65fc475 Check server versions on startup & refactor (#48)
* Refactor and check server parameters

* warnings

* fix validator
2022-02-26 11:01:52 -08:00
Lev Kokotov
206fdc9769 Fix some stats (#47)
* fix some stats

* use constant

* lint
2022-02-26 10:03:11 -08:00
Lev Kokotov
f74101cdfe admin: SHOW STATS (#46)
* admin: show stats

* warning

* tests

* lint

* type mod
2022-02-25 18:20:15 -08:00
Lev Kokotov
8e0682482d query routing docs (#45) 2022-02-25 14:27:33 -08:00
16 changed files with 894 additions and 95 deletions

View File

@@ -57,6 +57,17 @@ cd tests/ruby && \
ruby tests.rb && \
cd ../..
# Admin tests
psql -e -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW STATS' > /dev/null
psql -h 127.0.0.1 -p 6432 -d pgbouncer -c 'RELOAD' > /dev/null
psql -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW CONFIG' > /dev/null
psql -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW DATABASES' > /dev/null
psql -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW LISTS' > /dev/null
psql -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW POOLS' > /dev/null
psql -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW VERSION' > /dev/null
psql -h 127.0.0.1 -p 6432 -d pgbouncer -c "SET client_encoding TO 'utf8'" > /dev/null # will ignore
(! psql -e -h 127.0.0.1 -p 6432 -d random_db -c 'SHOW STATS' > /dev/null)
# Start PgCat in debug to demonstrate failover better
start_pgcat "debug"

3
.gitignore vendored
View File

@@ -1,2 +1,5 @@
/target
*.deb
.idea/*
tests/ruby/.bundle/*
tests/ruby/vendor/*

24
Cargo.lock generated
View File

@@ -220,6 +220,12 @@ dependencies = [
"cfg-if",
]
[[package]]
name = "itoa"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1aab8fc367588b89dcee83ab0fd66b72b50b72fa1904d7095045ace2b0c81c35"
[[package]]
name = "libc"
version = "0.2.117"
@@ -369,6 +375,7 @@ dependencies = [
"regex",
"serde",
"serde_derive",
"serde_json",
"sha-1",
"sqlparser",
"statsd",
@@ -478,6 +485,12 @@ version = "0.6.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f497285884f3fcff424ffc933e56d7cbca511def0c9831a7f9b5f6153e3cc89b"
[[package]]
name = "ryu"
version = "1.0.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73b4b750c782965c211b42f022f59af1fbceabdd026623714f104152f1ec149f"
[[package]]
name = "scopeguard"
version = "1.1.0"
@@ -501,6 +514,17 @@ dependencies = [
"syn",
]
[[package]]
name = "serde_json"
version = "1.0.79"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e8d9fa5c3b304765ce1fd9c4c8a3de2c8db365a5b91be52f186efc675681d95"
dependencies = [
"itoa",
"ryu",
"serde",
]
[[package]]
name = "sha-1"
version = "0.10.0"

View File

@@ -17,6 +17,7 @@ sha-1 = "0.10"
toml = "0.5"
serde = "1"
serde_derive = "1"
serde_json = "1"
regex = "1"
num_cpus = "1"
once_cell = "1"

View File

@@ -48,8 +48,8 @@ password = "sharding_user"
# [ host, port, role ]
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ],
["127.0.0.1", 5432, "primary"],
["localhost", 5432, "replica"],
# [ "127.0.1.1", 5432, "replica" ],
]
# Database name (e.g. "postgres")
@@ -58,8 +58,8 @@ database = "shard0"
[shards.1]
# [ host, port, role ]
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ],
["127.0.0.1", 5432, "primary"],
["localhost", 5432, "replica"],
# [ "127.0.1.1", 5432, "replica" ],
]
database = "shard1"
@@ -67,8 +67,8 @@ database = "shard1"
[shards.2]
# [ host, port, role ]
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ],
["127.0.0.1", 5432, "primary"],
["localhost", 5432, "replica"],
# [ "127.0.1.1", 5432, "replica" ],
]
database = "shard2"

351
src/admin.rs Normal file
View File

@@ -0,0 +1,351 @@
use bytes::{Buf, BufMut, BytesMut};
use log::{info, trace};
use tokio::net::tcp::OwnedWriteHalf;
use std::collections::HashMap;
use crate::config::{get_config, parse};
use crate::errors::Error;
use crate::messages::*;
use crate::pool::ConnectionPool;
use crate::stats::get_stats;
/// Handle admin client
pub async fn handle_admin(
stream: &mut OwnedWriteHalf,
mut query: BytesMut,
pool: ConnectionPool,
) -> Result<(), Error> {
let code = query.get_u8() as char;
if code != 'Q' {
return Err(Error::ProtocolSyncError);
}
let len = query.get_i32() as usize;
let query = String::from_utf8_lossy(&query[..len - 5])
.to_string()
.to_ascii_uppercase();
trace!("Admin query: {}", query);
if query.starts_with("SHOW STATS") {
trace!("SHOW STATS");
show_stats(stream).await
} else if query.starts_with("RELOAD") {
trace!("RELOAD");
reload(stream).await
} else if query.starts_with("SHOW CONFIG") {
trace!("SHOW CONFIG");
show_config(stream).await
} else if query.starts_with("SHOW DATABASES") {
trace!("SHOW DATABASES");
show_databases(stream, &pool).await
} else if query.starts_with("SHOW POOLS") {
trace!("SHOW POOLS");
show_pools(stream, &pool).await
} else if query.starts_with("SHOW LISTS") {
trace!("SHOW LISTS");
show_lists(stream, &pool).await
} else if query.starts_with("SHOW VERSION") {
trace!("SHOW VERSION");
show_version(stream).await
} else if query.starts_with("SET ") {
trace!("SET");
ignore_set(stream).await
} else {
error_response(stream, "Unsupported query against the admin database").await
}
}
/// SHOW LISTS
async fn show_lists(stream: &mut OwnedWriteHalf, pool: &ConnectionPool) -> Result<(), Error> {
let stats = get_stats();
let columns = vec![("list", DataType::Text), ("items", DataType::Int4)];
let mut res = BytesMut::new();
res.put(row_description(&columns));
res.put(data_row(&vec![
"databases".to_string(),
(pool.databases() + 1).to_string(), // see comment below
]));
res.put(data_row(&vec!["users".to_string(), "1".to_string()]));
res.put(data_row(&vec![
"pools".to_string(),
(pool.databases() + 1).to_string(), // +1 for the pgbouncer admin db pool which isn't real
])); // but admin tools that work with pgbouncer want this
res.put(data_row(&vec![
"free_clients".to_string(),
stats["cl_idle"].to_string(),
]));
res.put(data_row(&vec![
"used_clients".to_string(),
stats["cl_active"].to_string(),
]));
res.put(data_row(&vec![
"login_clients".to_string(),
"0".to_string(),
]));
res.put(data_row(&vec![
"free_servers".to_string(),
stats["sv_idle"].to_string(),
]));
res.put(data_row(&vec![
"used_servers".to_string(),
stats["sv_active"].to_string(),
]));
res.put(data_row(&vec!["dns_names".to_string(), "0".to_string()]));
res.put(data_row(&vec!["dns_zones".to_string(), "0".to_string()]));
res.put(data_row(&vec!["dns_queries".to_string(), "0".to_string()]));
res.put(data_row(&vec!["dns_pending".to_string(), "0".to_string()]));
res.put(command_complete("SHOW"));
res.put_u8(b'Z');
res.put_i32(5);
res.put_u8(b'I');
write_all_half(stream, res).await
}
/// SHOW VERSION
async fn show_version(stream: &mut OwnedWriteHalf) -> Result<(), Error> {
let mut res = BytesMut::new();
res.put(row_description(&vec![("version", DataType::Text)]));
res.put(data_row(&vec!["PgCat 0.1.0".to_string()]));
res.put(command_complete("SHOW"));
res.put_u8(b'Z');
res.put_i32(5);
res.put_u8(b'I');
write_all_half(stream, res).await
}
/// SHOW POOLS
async fn show_pools(stream: &mut OwnedWriteHalf, _pool: &ConnectionPool) -> Result<(), Error> {
let stats = get_stats();
let config = {
let guard = get_config();
&*guard.clone()
};
let columns = vec![
("database", DataType::Text),
("user", DataType::Text),
("cl_active", DataType::Numeric),
("cl_waiting", DataType::Numeric),
("cl_cancel_req", DataType::Numeric),
("sv_active", DataType::Numeric),
("sv_idle", DataType::Numeric),
("sv_used", DataType::Numeric),
("sv_tested", DataType::Numeric),
("sv_login", DataType::Numeric),
("maxwait", DataType::Numeric),
("maxwait_us", DataType::Numeric),
("pool_mode", DataType::Text),
];
let mut res = BytesMut::new();
res.put(row_description(&columns));
let mut row = vec![String::from("all"), config.user.name.clone()];
for column in &columns[2..columns.len() - 1] {
let value = stats.get(column.0).unwrap_or(&0).to_string();
row.push(value);
}
row.push(config.general.pool_mode.to_string());
res.put(data_row(&row));
res.put(command_complete("SHOW"));
res.put_u8(b'Z');
res.put_i32(5);
res.put_u8(b'I');
write_all_half(stream, res).await
}
/// SHOW DATABASES
async fn show_databases(stream: &mut OwnedWriteHalf, pool: &ConnectionPool) -> Result<(), Error> {
let guard = get_config();
let config = &*guard.clone();
drop(guard);
// Columns
let columns = vec![
("name", DataType::Text),
("host", DataType::Text),
("port", DataType::Text),
("database", DataType::Text),
("force_user", DataType::Text),
("pool_size", DataType::Int4),
("min_pool_size", DataType::Int4),
("reserve_pool", DataType::Int4),
("pool_mode", DataType::Text),
("max_connections", DataType::Int4),
("current_connections", DataType::Int4),
("paused", DataType::Int4),
("disabled", DataType::Int4),
];
let mut res = BytesMut::new();
// RowDescription
res.put(row_description(&columns));
for shard in 0..pool.shards() {
let database_name = &config.shards[&shard.to_string()].database;
for server in 0..pool.servers(shard) {
let address = pool.address(shard, server);
let pool_state = pool.pool_state(shard, server);
res.put(data_row(&vec![
address.name(), // name
address.host.to_string(), // host
address.port.to_string(), // port
database_name.to_string(), // database
config.user.name.to_string(), // force_user
config.general.pool_size.to_string(), // pool_size
"0".to_string(), // min_pool_size
"0".to_string(), // reserve_pool
config.general.pool_mode.to_string(), // pool_mode
config.general.pool_size.to_string(), // max_connections
pool_state.connections.to_string(), // current_connections
"0".to_string(), // paused
"0".to_string(), // disabled
]));
}
}
res.put(command_complete("SHOW"));
// ReadyForQuery
res.put_u8(b'Z');
res.put_i32(5);
res.put_u8(b'I');
write_all_half(stream, res).await
}
/// Ignore any SET commands the client sends.
/// This is common initialization done by ORMs.
async fn ignore_set(stream: &mut OwnedWriteHalf) -> Result<(), Error> {
custom_protocol_response_ok(stream, "SET").await
}
/// RELOAD
async fn reload(stream: &mut OwnedWriteHalf) -> Result<(), Error> {
info!("Reloading config");
let config = get_config();
let path = config.path.clone().unwrap();
parse(&path).await?;
let config = get_config();
config.show();
let mut res = BytesMut::new();
// CommandComplete
res.put(command_complete("RELOAD"));
// ReadyForQuery
res.put_u8(b'Z');
res.put_i32(5);
res.put_u8(b'I');
write_all_half(stream, res).await
}
async fn show_config(stream: &mut OwnedWriteHalf) -> Result<(), Error> {
let guard = get_config();
let config = &*guard.clone();
let config: HashMap<String, String> = config.into();
drop(guard);
// Configs that cannot be changed dynamically.
let immutables = ["host", "port", "connect_timeout"];
// Columns
let columns = vec![
("key", DataType::Text),
("value", DataType::Text),
("default", DataType::Text),
("changeable", DataType::Text),
];
// Response data
let mut res = BytesMut::new();
res.put(row_description(&columns));
// DataRow rows
for (key, value) in config {
let changeable = if immutables.iter().filter(|col| *col == &key).count() == 1 {
"no".to_string()
} else {
"yes".to_string()
};
let row = vec![key, value, "-".to_string(), changeable];
res.put(data_row(&row));
}
res.put(command_complete("SHOW"));
res.put_u8(b'Z');
res.put_i32(5);
res.put_u8(b'I');
write_all_half(stream, res).await
}
/// SHOW STATS
async fn show_stats(stream: &mut OwnedWriteHalf) -> Result<(), Error> {
let columns = vec![
("database", DataType::Text),
("total_xact_count", DataType::Numeric),
("total_query_count", DataType::Numeric),
("total_received", DataType::Numeric),
("total_sent", DataType::Numeric),
("total_xact_time", DataType::Numeric),
("total_query_time", DataType::Numeric),
("total_wait_time", DataType::Numeric),
("avg_xact_count", DataType::Numeric),
("avg_query_count", DataType::Numeric),
("avg_recv", DataType::Numeric),
("avg_sent", DataType::Numeric),
("avg_xact_time", DataType::Numeric),
("avg_query_time", DataType::Numeric),
("avg_wait_time", DataType::Numeric),
];
let stats = get_stats();
let mut res = BytesMut::new();
res.put(row_description(&columns));
let mut row = vec![
String::from("all"), // TODO: per-database stats,
];
for column in &columns[1..] {
row.push(stats.get(column.0).unwrap_or(&0).to_string());
}
res.put(data_row(&row));
res.put(command_complete("SHOW"));
res.put_u8(b'Z');
res.put_i32(5);
res.put_u8(b'I');
write_all_half(stream, res).await
}

View File

@@ -11,6 +11,7 @@ use tokio::net::{
use std::collections::HashMap;
use crate::admin::handle_admin;
use crate::config::get_config;
use crate::constants::*;
use crate::errors::Error;
@@ -54,6 +55,9 @@ pub struct Client {
// Statistics
stats: Reporter,
// Clients want to talk to admin
admin: bool,
}
impl Client {
@@ -104,20 +108,32 @@ impl Client {
// Regular startup message.
PROTOCOL_VERSION_NUMBER => {
trace!("Got StartupMessage");
// TODO: perform actual auth.
let parameters = parse_startup(bytes.clone())?;
let mut user_name: String = String::new();
match parameters.get(&"user") {
Some(&user) => user_name = user,
None => return Err(Error::ClientBadStartup),
}
start_auth(&mut stream, &user_name).await?;
// Generate random backend ID and secret key
let process_id: i32 = rand::random();
let secret_key: i32 = rand::random();
auth_ok(&mut stream).await?;
write_all(&mut stream, server_info).await?;
backend_key_data(&mut stream, process_id, secret_key).await?;
ready_for_query(&mut stream).await?;
trace!("Startup OK");
let database = parameters
.get("database")
.unwrap_or(parameters.get("user").unwrap());
let admin = ["pgcat", "pgbouncer"]
.iter()
.filter(|db| *db == &database)
.count()
== 1;
// Split the read and write streams
// so we can control buffering.
let (read, write) = stream.into_split();
@@ -133,6 +149,7 @@ impl Client {
client_server_map: client_server_map,
parameters: parameters,
stats: stats,
admin: admin,
});
}
@@ -154,6 +171,7 @@ impl Client {
client_server_map: client_server_map,
parameters: HashMap::new(),
stats: stats,
admin: false,
});
}
@@ -220,6 +238,13 @@ impl Client {
return Ok(());
}
// Handle admin database real quick
if self.admin {
trace!("Handling admin command");
handle_admin(&mut self.write, message, pool.clone()).await?;
continue;
}
// Handle all custom protocol commands here.
match query_router.try_execute_command(message.clone()) {
// Normal query

View File

@@ -19,6 +19,15 @@ pub enum Role {
Replica,
}
impl ToString for Role {
fn to_string(&self) -> String {
match *self {
Role::Primary => "primary".to_string(),
Role::Replica => "replica".to_string(),
}
}
}
impl PartialEq<Option<Role>> for Role {
fn eq(&self, other: &Option<Role>) -> bool {
match other {
@@ -43,6 +52,7 @@ pub struct Address {
pub port: String,
pub shard: usize,
pub role: Role,
pub replica_number: usize,
}
impl Default for Address {
@@ -51,11 +61,22 @@ impl Default for Address {
host: String::from("127.0.0.1"),
port: String::from("5432"),
shard: 0,
replica_number: 0,
role: Role::Replica,
}
}
}
impl Address {
pub fn name(&self) -> String {
match self.role {
Role::Primary => format!("shard_{}_primary", self.shard),
Role::Replica => format!("shard_{}_replica_{}", self.shard, self.replica_number),
}
}
}
#[derive(Clone, PartialEq, Hash, std::cmp::Eq, Deserialize, Debug)]
pub struct User {
pub name: String,
@@ -134,6 +155,7 @@ impl Default for QueryRouter {
#[derive(Deserialize, Debug, Clone)]
pub struct Config {
pub path: Option<String>,
pub general: General,
pub user: User,
pub shards: HashMap<String, Shard>,
@@ -143,6 +165,7 @@ pub struct Config {
impl Default for Config {
fn default() -> Config {
Config {
path: Some(String::from("pgcat.toml")),
general: General::default(),
user: User::default(),
shards: HashMap::from([(String::from("1"), Shard::default())]),
@@ -151,6 +174,52 @@ impl Default for Config {
}
}
impl From<&Config> for std::collections::HashMap<String, String> {
fn from(config: &Config) -> HashMap<String, String> {
HashMap::from([
("host".to_string(), config.general.host.to_string()),
("port".to_string(), config.general.port.to_string()),
(
"pool_size".to_string(),
config.general.pool_size.to_string(),
),
(
"pool_mode".to_string(),
config.general.pool_mode.to_string(),
),
(
"connect_timeout".to_string(),
config.general.connect_timeout.to_string(),
),
(
"healthcheck_timeout".to_string(),
config.general.healthcheck_timeout.to_string(),
),
("ban_time".to_string(), config.general.ban_time.to_string()),
(
"statsd_address".to_string(),
config.general.statsd_address.to_string(),
),
(
"default_role".to_string(),
config.query_router.default_role.to_string(),
),
(
"query_parser_enabled".to_string(),
config.query_router.query_parser_enabled.to_string(),
),
(
"primary_reads_enabled".to_string(),
config.query_router.primary_reads_enabled.to_string(),
),
(
"sharding_function".to_string(),
config.query_router.sharding_function.to_string(),
),
])
}
}
impl Config {
pub fn show(&self) {
info!("Pool size: {}", self.general.pool_size);
@@ -189,7 +258,7 @@ pub async fn parse(path: &str) -> Result<(), Error> {
}
};
let config: Config = match toml::from_str(&contents) {
let mut config: Config = match toml::from_str(&contents) {
Ok(config) => config,
Err(err) => {
error!("Could not parse config file: {}", err.to_string());
@@ -279,6 +348,8 @@ pub async fn parse(path: &str) -> Result<(), Error> {
}
};
config.path = Some(path.to_string());
CONFIG.store(Arc::new(config.clone()));
Ok(())
@@ -296,5 +367,6 @@ mod test {
assert_eq!(get_config().shards["1"].servers[0].0, "127.0.0.1");
assert_eq!(get_config().shards["0"].servers[0].2, "primary");
assert_eq!(get_config().query_router.default_role, "any");
assert_eq!(get_config().path, Some("pgcat.toml".to_string()));
}
}

View File

@@ -20,3 +20,8 @@ pub const AUTHENTICATION_SUCCESSFUL: i32 = 0;
// ErrorResponse: A code identifying the field type; if zero, this is the message terminator and no string follows.
pub const MESSAGE_TERMINATOR: u8 = 0;
//
// Data types
//
pub const _OID_INT8: i32 = 20; // bigint

View File

@@ -8,5 +8,7 @@ pub enum Error {
// ServerTimeout,
// DirtyServer,
BadConfig,
BadUserList,
AllServersDown,
AuthenticationError
}

View File

@@ -47,8 +47,10 @@ use tokio::{
use std::collections::HashMap;
use std::sync::Arc;
mod admin;
mod client;
mod config;
mod userlist;
mod constants;
mod errors;
mod messages;
@@ -93,6 +95,15 @@ async fn main() {
}
};
// Prepare user list
match userlist::parse("userlist.json").await {
Ok(_) => (),
Err(err) => {
error!("Userlist parse error: {:?}", err);
return;
}
};
let config = get_config();
let addr = format!("{}:{}", config.general.host, config.general.port);

View File

@@ -1,5 +1,7 @@
/// Helper functions to send one-off protocol messages
/// and handle TcpStream (TCP socket).
use bytes::{Buf, BufMut, BytesMut};
use md5::{Digest, Md5};
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufReader};
@@ -7,10 +9,125 @@ use tokio::net::{
tcp::{OwnedReadHalf, OwnedWriteHalf},
TcpStream,
};
use std::collections::HashMap;
use log::{error};
use crate::errors::Error;
use std::collections::HashMap;
use rand::Rng;
use crate::userlist::get_user_list;
/// Postgres data type mappings
/// used in RowDescription ('T') message.
pub enum DataType {
Text,
Int4,
Numeric,
}
impl From<&DataType> for i32 {
fn from(data_type: &DataType) -> i32 {
match data_type {
DataType::Text => 25,
DataType::Int4 => 23,
DataType::Numeric => 1700,
}
}
}
/**
1. Generate salt (4 bytes of random data)
md5(concat(md5(concat(password, username)), random-salt)))
2. Send md5 auth request
3. recieve PasswordMessage with salt.
4. refactor md5_password function to be reusable
5. check username hash combo against file
6. AuthenticationOk or ErrorResponse
**/
pub async fn start_auth(stream: &mut TcpStream, user_name: &String) -> Result<(), Error> {
let mut rng = rand::thread_rng();
//Generate random 4 byte salt
let salt = rng.gen::<u32>();
// Send AuthenticationMD5Password request
send_md5_request(stream, salt).await?;
let code = match stream.read_u8().await {
Ok(code) => code as char,
Err(_) => return Err(Error::AuthenticationError),
};
match code {
// Password response
'p' => {
fetch_password_and_authenticate(stream, &user_name, &salt).await?;
Ok(auth_ok(stream).await?)
}
_ => {
error!("Unknown code: {}", code);
return Err(Error::AuthenticationError);
}
}
}
pub async fn send_md5_request(stream: &mut TcpStream, salt: u32) -> Result<(), Error> {
let mut authentication_md5password = BytesMut::with_capacity(12);
authentication_md5password.put_u8(b'R');
authentication_md5password.put_i32(12);
authentication_md5password.put_i32(5);
authentication_md5password.put_u32(salt);
// Send AuthenticationMD5Password request
Ok(write_all(stream, authentication_md5password).await?)
}
pub async fn fetch_password_and_authenticate(stream: &mut TcpStream, user_name: &String, salt: &u32) -> Result<(), Error> {
/**
1. How do I store the lists of users and paswords? clear text or hash?? wtf
2. Add auth to tests
**/
let len = match stream.read_i32().await {
Ok(len) => len,
Err(_) => return Err(Error::AuthenticationError),
};
// Read whatever is left.
let mut password_hash = vec![0u8; len as usize - 4];
match stream.read_exact(&mut password_hash).await {
Ok(_) => (),
Err(_) => return Err(Error::AuthenticationError),
};
let user_list = get_user_list();
let mut password: String = String::new();
match user_list.get(&user_name) {
Some(&p) => password = p,
None => return Err(Error::AuthenticationError),
}
let mut md5 = Md5::new();
// concat('md5', md5(concat(md5(concat(password, username)), random-salt)))
// First pass
md5.update(&password.as_bytes());
md5.update(&user_name.as_bytes());
let output = md5.finalize_reset();
// Second pass
md5.update(format!("{:x}", output));
md5.update(salt.to_be_bytes().to_vec());
let password_string: String = String::from_utf8(password_hash).expect("Could not get password hash");
match format!("md5{:x}", md5.finalize()) == password_string {
true => Ok(()),
_ => Err(Error::AuthenticationError)
}
}
/// Tell the client that authentication handshake completed successfully.
pub async fn auth_ok(stream: &mut TcpStream) -> Result<(), Error> {
@@ -91,9 +208,8 @@ pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Resu
}
}
/// Parse StartupMessage parameters.
/// e.g. user, database, application_name, etc.
pub fn parse_startup(mut bytes: BytesMut) -> Result<HashMap<String, String>, Error> {
/// Parse the params the server sends as a key/value format.
pub fn parse_params(mut bytes: BytesMut) -> Result<HashMap<String, String>, Error> {
let mut result = HashMap::new();
let mut buf = Vec::new();
let mut tmp = String::new();
@@ -115,7 +231,7 @@ pub fn parse_startup(mut bytes: BytesMut) -> Result<HashMap<String, String>, Err
// Expect pairs of name and value
// and at least one pair to be present.
if buf.len() % 2 != 0 && buf.len() >= 2 {
if buf.len() % 2 != 0 || buf.len() < 2 {
return Err(Error::ClientBadStartup);
}
@@ -127,6 +243,14 @@ pub fn parse_startup(mut bytes: BytesMut) -> Result<HashMap<String, String>, Err
i += 2;
}
Ok(result)
}
/// Parse StartupMessage parameters.
/// e.g. user, database, application_name, etc.
pub fn parse_startup(bytes: BytesMut) -> Result<HashMap<String, String>, Error> {
let result = parse_params(bytes)?;
// Minimum required parameters
// I want to have the user at the very minimum, according to the protocol spec.
if !result.contains_key("user") {
@@ -252,68 +376,17 @@ pub async fn show_response(
// 3. CommandComplete
// 4. ReadyForQuery
// RowDescription
let mut row_desc = BytesMut::new();
// Number of columns: 1
row_desc.put_i16(1);
// Column name
row_desc.put_slice(&format!("{}\0", name).as_bytes());
// Doesn't belong to any table
row_desc.put_i32(0);
// Doesn't belong to any table
row_desc.put_i16(0);
// Text
row_desc.put_i32(25);
// Text size = variable (-1)
row_desc.put_i16(-1);
// Type modifier: none that I know
row_desc.put_i32(0);
// Format being used: text (0), binary (1)
row_desc.put_i16(0);
// DataRow
let mut data_row = BytesMut::new();
// Number of columns
data_row.put_i16(1);
// Size of the column content (length of the string really)
data_row.put_i32(value.len() as i32);
// The content
data_row.put_slice(value.as_bytes());
// CommandComplete
let mut command_complete = BytesMut::new();
// Number of rows returned (just one)
command_complete.put_slice(&b"SELECT 1\0"[..]);
// The final messages sent to the client
let mut res = BytesMut::new();
// RowDescription
res.put_u8(b'T');
res.put_i32(row_desc.len() as i32 + 4);
res.put(row_desc);
res.put(row_description(&vec![(name, DataType::Text)]));
// DataRow
res.put_u8(b'D');
res.put_i32(data_row.len() as i32 + 4);
res.put(data_row);
res.put(data_row(&vec![value.to_string()]));
// CommandComplete
res.put_u8(b'C');
res.put_i32(command_complete.len() as i32 + 4);
res.put(command_complete);
res.put(command_complete("SELECT 1"));
// ReadyForQuery
res.put_u8(b'Z');
@@ -323,6 +396,77 @@ pub async fn show_response(
write_all_half(stream, res).await
}
pub fn row_description(columns: &Vec<(&str, DataType)>) -> BytesMut {
let mut res = BytesMut::new();
let mut row_desc = BytesMut::new();
// how many colums we are storing
row_desc.put_i16(columns.len() as i16);
for (name, data_type) in columns {
// Column name
row_desc.put_slice(&format!("{}\0", name).as_bytes());
// Doesn't belong to any table
row_desc.put_i32(0);
// Doesn't belong to any table
row_desc.put_i16(0);
// Text
row_desc.put_i32(data_type.into());
// Text size = variable (-1)
let type_size = match data_type {
DataType::Text => -1,
DataType::Int4 => 4,
DataType::Numeric => -1,
};
row_desc.put_i16(type_size);
// Type modifier: none that I know
row_desc.put_i32(-1);
// Format being used: text (0), binary (1)
row_desc.put_i16(0);
}
res.put_u8(b'T');
res.put_i32(row_desc.len() as i32 + 4);
res.put(row_desc);
res
}
pub fn data_row(row: &Vec<String>) -> BytesMut {
let mut res = BytesMut::new();
let mut data_row = BytesMut::new();
data_row.put_i16(row.len() as i16);
for column in row {
let column = column.as_bytes();
data_row.put_i32(column.len() as i32);
data_row.put_slice(&column);
}
res.put_u8(b'D');
res.put_i32(data_row.len() as i32 + 4);
res.put(data_row);
res
}
pub fn command_complete(command: &str) -> BytesMut {
let cmd = BytesMut::from(format!("{}\0", command).as_bytes());
let mut res = BytesMut::new();
res.put_u8(b'C');
res.put_i32(cmd.len() as i32 + 4);
res.put(cmd);
res
}
/// Write all data in the buffer to the TcpStream.
pub async fn write_all(stream: &mut TcpStream, buf: BytesMut) -> Result<(), Error> {
match stream.write_all(&buf).await {

View File

@@ -49,7 +49,8 @@ impl ConnectionPool {
for shard_idx in shard_ids {
let shard = &config.shards[&shard_idx];
let mut pools = Vec::new();
let mut replica_addresses = Vec::new();
let mut servers = Vec::new();
let mut replica_number = 0;
for server in shard.servers.iter() {
let role = match server.2.as_ref() {
@@ -65,9 +66,14 @@ impl ConnectionPool {
host: server.0.clone(),
port: server.1.to_string(),
role: role,
replica_number,
shard: shard_idx.parse::<usize>().unwrap(),
};
if role == Role::Replica {
replica_number += 1;
}
let manager = ServerPool::new(
address.clone(),
config.user.clone(),
@@ -87,11 +93,11 @@ impl ConnectionPool {
.unwrap();
pools.push(pool);
replica_addresses.push(address);
servers.push(address);
}
shards.push(pools);
addresses.push(replica_addresses);
addresses.push(servers);
banlist.push(HashMap::new());
}
@@ -126,10 +132,22 @@ impl ConnectionPool {
};
let mut proxy = connection.0;
let _address = connection.1;
let address = connection.1;
let server = &mut *proxy;
server_infos.push(server.server_info());
let server_info = server.server_info();
if server_infos.len() > 0 {
// Compare against the last server checked.
if server_info != server_infos[server_infos.len() - 1] {
warn!(
"{:?} has different server configuration than the last server",
address
);
}
}
server_infos.push(server_info);
}
}
@@ -324,6 +342,22 @@ impl ConnectionPool {
pub fn servers(&self, shard: usize) -> usize {
self.addresses[shard].len()
}
pub fn databases(&self) -> usize {
let mut databases = 0;
for shard in 0..self.shards() {
databases += self.servers(shard);
}
databases
}
pub fn pool_state(&self, shard: usize, server: usize) -> bb8::State {
self.databases[shard][server].state()
}
pub fn address(&self, shard: usize, server: usize) -> &Address {
&self.addresses[shard][server]
}
}
pub struct ServerPool {

View File

@@ -1,12 +1,17 @@
use log::info;
use log::{debug, info};
use once_cell::sync::Lazy;
use parking_lot::Mutex;
use statsd::Client;
/// Events collector and publisher.
use tokio::sync::mpsc::{Receiver, Sender};
use std::collections::HashMap;
use crate::config::get_config;
// Stats used in SHOW STATS
static LATEST_STATS: Lazy<Mutex<HashMap<String, i64>>> = Lazy::new(|| Mutex::new(HashMap::new()));
static STAT_PERIOD: u64 = 15000; //15 seconds
#[derive(Debug, Clone, Copy)]
enum EventName {
CheckoutTime,
@@ -182,16 +187,6 @@ impl Reporter {
let _ = self.tx.try_send(event);
}
// pub fn flush_to_statsd(&self) {
// let event = Event {
// name: EventName::FlushStatsToStatsD,
// value: 0,
// process_id: None,
// };
// let _ = self.tx.try_send(event);
// }
}
pub struct Collector {
@@ -217,7 +212,15 @@ impl Collector {
("total_xact_count", 0),
("total_sent", 0),
("total_received", 0),
("total_xact_time", 0),
("total_query_time", 0),
("total_wait_time", 0),
("avg_xact_time", 0),
("avg_query_time", 0),
("avg_xact_count", 0),
("avg_sent", 0),
("avg_received", 0),
("avg_wait_time", 0),
("maxwait_us", 0),
("maxwait", 0),
("cl_waiting", 0),
@@ -229,11 +232,18 @@ impl Collector {
("sv_tested", 0),
]);
let mut client_server_states: HashMap<i32, EventName> = HashMap::new();
let tx = self.tx.clone();
// Stats saved after each iteration of the flush event. Used in calculation
// of averages in the last flush period.
let mut old_stats: HashMap<String, i64> = HashMap::new();
// Track which state the client and server are at any given time.
let mut client_server_states: HashMap<i32, EventName> = HashMap::new();
// Flush stats to StatsD and calculate averages every 15 seconds.
let tx = self.tx.clone();
tokio::task::spawn(async move {
let mut interval = tokio::time::interval(tokio::time::Duration::from_millis(15000));
let mut interval =
tokio::time::interval(tokio::time::Duration::from_millis(STAT_PERIOD));
loop {
interval.tick().await;
let _ = tx.try_send(Event {
@@ -244,6 +254,7 @@ impl Collector {
}
});
// The collector loop
loop {
let stat = match self.rx.recv().await {
Some(stat) => stat,
@@ -280,10 +291,11 @@ impl Collector {
*counter += stat.value;
let counter = stats.entry("maxwait_us").or_insert(0);
let mic_part = stat.value % 1_000_000;
// Report max time here
if stat.value > *counter {
*counter = stat.value;
if mic_part > *counter {
*counter = mic_part;
}
let counter = stats.entry("maxwait").or_insert(0);
@@ -309,6 +321,7 @@ impl Collector {
}
EventName::FlushStatsToStatsD => {
// Calculate connection states
for (_, state) in &client_server_states {
match state {
EventName::ClientActive => {
@@ -350,13 +363,51 @@ impl Collector {
};
}
info!("{:?}", stats);
// Calculate averages
for stat in &[
"avg_query_count",
"avgxact_count",
"avg_sent",
"avg_received",
"avg_wait_time",
] {
let total_name = stat.replace("avg_", "total_");
let old_value = old_stats.entry(total_name.clone()).or_insert(0);
let new_value = stats.get(total_name.as_str()).unwrap_or(&0).to_owned();
let avg = (new_value - *old_value) / (STAT_PERIOD as i64 / 1_000); // Avg / second
stats.insert(stat, avg);
*old_value = new_value;
}
debug!("{:?}", stats);
// Update latest stats used in SHOW STATS
let mut guard = LATEST_STATS.lock();
for (key, value) in &stats {
guard.insert(key.to_string(), value.clone());
}
let mut pipeline = self.client.pipeline();
for (key, value) in stats.iter_mut() {
for (key, value) in stats.iter() {
pipeline.gauge(key, *value as f64);
*value = 0;
}
// These are re-calculated every iteration of the loop, so we don't want to add values
// from the last iteration.
for stat in &[
"cl_active",
"cl_waiting",
"cl_idle",
"sv_idle",
"sv_active",
"sv_tested",
"sv_login",
"maxwait",
"maxwait_us",
] {
stats.insert(stat, 0);
}
pipeline.send(&self.client);
@@ -365,3 +416,7 @@ impl Collector {
}
}
}
pub fn get_stats() -> HashMap<String, i64> {
LATEST_STATS.lock().clone()
}

4
src/userlist.json Normal file
View File

@@ -0,0 +1,4 @@
{
"sven": "clear_text_password",
"sharding_user": "sharding_user"
}

57
src/userlist.rs Normal file
View File

@@ -0,0 +1,57 @@
use arc_swap::{ArcSwap, Guard};
use log::{error};
use once_cell::sync::Lazy;
use tokio::fs::File;
use tokio::io::AsyncReadExt;
use std::collections::{HashMap};
use std::sync::Arc;
use crate::errors::Error;
pub type UserList = HashMap<String, String>;
static USER_LIST: Lazy<ArcSwap<UserList>> = Lazy::new(|| ArcSwap::from_pointee(HashMap::new()));
pub fn get_user_list() -> Guard<Arc<UserList>> {
USER_LIST.load()
}
/// Parse the user list.
pub async fn parse(path: &str) -> Result<(), Error> {
let mut contents = String::new();
let mut file = match File::open(path).await {
Ok(file) => file,
Err(err) => {
error!("Could not open '{}': {}", path, err.to_string());
return Err(Error::BadConfig);
}
};
match file.read_to_string(&mut contents).await {
Ok(_) => (),
Err(err) => {
error!("Could not read config file: {}", err.to_string());
return Err(Error::BadConfig);
}
};
let map: HashMap<String, String> = serde_json::from_str(&contents).expect("JSON was not well-formatted");
USER_LIST.store(Arc::new(map.clone()));
Ok(())
}
#[cfg(test)]
mod test {
use super::*;
#[tokio::test]
async fn test_config() {
parse("userlist.json").await.unwrap();
assert_eq!(get_user_list()["sven"], "clear_text_password");
assert_eq!(get_user_list()["sharding_user"], "sharding_user");
}
}