diff --git a/.circleci/run_tests.sh b/.circleci/run_tests.sh index e311904..22ad483 100644 --- a/.circleci/run_tests.sh +++ b/.circleci/run_tests.sh @@ -61,6 +61,7 @@ cd ../.. psql -e -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW STATS' > /dev/null psql -h 127.0.0.1 -p 6432 -d pgbouncer -c 'RELOAD' > /dev/null psql -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW CONFIG' > /dev/null +psql -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW DATABASES' > /dev/null (! psql -e -h 127.0.0.1 -p 6432 -d random_db -c 'SHOW STATS' > /dev/null) # Start PgCat in debug to demonstrate failover better diff --git a/src/admin.rs b/src/admin.rs index 43c46af..fb4a718 100644 --- a/src/admin.rs +++ b/src/admin.rs @@ -4,14 +4,19 @@ use tokio::net::tcp::OwnedWriteHalf; use std::collections::HashMap; -use crate::config::{get_config, parse}; -use crate::constants::{OID_NUMERIC, OID_TEXT}; +use crate::config::{get_config, parse, Role}; +use crate::constants::{OID_INT4, OID_NUMERIC, OID_TEXT}; use crate::errors::Error; -use crate::messages::write_all_half; +use crate::messages::{custom_protocol_response_ok, error_response, write_all_half}; +use crate::pool::ConnectionPool; use crate::stats::get_stats; /// Handle admin client -pub async fn handle_admin(stream: &mut OwnedWriteHalf, mut query: BytesMut) -> Result<(), Error> { +pub async fn handle_admin( + stream: &mut OwnedWriteHalf, + mut query: BytesMut, + pool: ConnectionPool, +) -> Result<(), Error> { let code = query.get_u8() as char; if code != 'Q' { @@ -23,6 +28,8 @@ pub async fn handle_admin(stream: &mut OwnedWriteHalf, mut query: BytesMut) -> R .to_string() .to_ascii_uppercase(); + trace!("Admin query: {}", query); + if query.starts_with("SHOW STATS") { trace!("SHOW STATS"); show_stats(stream).await @@ -32,13 +39,147 @@ pub async fn handle_admin(stream: &mut OwnedWriteHalf, mut query: BytesMut) -> R } else if query.starts_with("SHOW CONFIG") { trace!("SHOW CONFIG"); show_config(stream).await + } else if query.starts_with("SHOW DATABASES") { + trace!("SHOW DATABASES"); + show_databases(stream, &pool).await + } else if query.starts_with("SET ") { + trace!("SET"); + ignore_set(stream).await } else { - Err(Error::ProtocolSyncError) + error_response(stream, "Unsupported query against the admin database").await } } +/// SHOW DATABASES +async fn show_databases(stream: &mut OwnedWriteHalf, pool: &ConnectionPool) -> Result<(), Error> { + let guard = get_config(); + let config = &*guard.clone(); + drop(guard); + + let columns = [ + "name", + "host", + "port", + "database", + "force_user", + "pool_size", + "min_pool_size", + "reserve_pool", + "pool_mode", + "max_connections", + "current_connections", + "paused", + "disabled", + ]; + + let types = [ + OID_TEXT, OID_TEXT, OID_TEXT, OID_TEXT, OID_TEXT, OID_INT4, OID_INT4, OID_INT4, OID_TEXT, + OID_INT4, OID_INT4, OID_INT4, OID_INT4, + ]; + + let mut res = BytesMut::new(); + let mut row_desc = BytesMut::new(); + row_desc.put_i16(columns.len() as i16); + + for (i, column) in columns.iter().enumerate() { + row_desc.put_slice(&format!("{}\0", column).as_bytes()); + + // Doesn't belong to any table + row_desc.put_i32(0); + + // Doesn't belong to any table + row_desc.put_i16(0); + + // Data type + row_desc.put_i32(types[i]); + + // text size = variable (-1) + row_desc.put_i16(if types[i] == OID_TEXT { -1 } else { 4 }); + + // Type modifier: none that I know + row_desc.put_i32(-1); + + // Format being used: text (0), binary (1) + row_desc.put_i16(0); + } + + res.put_u8(b'T'); + res.put_i32(row_desc.len() as i32 + 4); + res.put(row_desc); + + for shard in 0..pool.shards() { + let database_name = &config.shards[&shard.to_string()].database; + let mut replica_count = 0; + for server in 0..pool.servers(shard) { + // DataRow + let mut data_row = BytesMut::new(); + data_row.put_i16(columns.len() as i16); + + let address = pool.address(shard, server); + let role = address.role.to_string(); + let name = match role.as_ref() { + "primary" => format!("shard_{}_primary", shard), + "replica" => format!("shard_{}_replica_{}", shard, replica_count), + _ => unreachable!(), + }; + let connections = pool.connections(shard, server); + + let data = HashMap::from([ + ("host", address.host.to_string()), + ("port", address.port.to_string()), + ("role", role), + ("name", name), + ("database", database_name.to_string()), + ("force_user", config.user.name.to_string()), + ("pool_size", config.general.pool_size.to_string()), + ("min_pool_size", "0".to_string()), + ("reserve_pool", "0".to_string()), + ("pool_mode", config.general.pool_mode.to_string()), + // There is only one user support at the moment, + // so max_connections = num of users * pool_size = 1 * pool_size. + ("max_connections", config.general.pool_size.to_string()), + ("current_connections", connections.connections.to_string()), + ("paused", "0".to_string()), + ("disabled", "0".to_string()), + ]); + + for column in &columns { + let value = data[column].as_bytes(); + + data_row.put_i32(value.len() as i32); + data_row.put_slice(&value); + } + + res.put_u8(b'D'); + res.put_i32(data_row.len() as i32 + 4); + res.put(data_row); + + if address.role == Role::Replica { + replica_count += 1; + } + } + } + + let command_complete = BytesMut::from(&"SHOW\0"[..]); + res.put_u8(b'C'); + res.put_i32(command_complete.len() as i32 + 4); + res.put(command_complete); + + res.put_u8(b'Z'); + res.put_i32(5); + res.put_u8(b'I'); + + write_all_half(stream, res).await +} + +/// Ignore any SET commands the client sends. +/// This is common initialization done by ORMs. +async fn ignore_set(stream: &mut OwnedWriteHalf) -> Result<(), Error> { + custom_protocol_response_ok(stream, "SET").await +} + /// RELOAD -pub async fn reload(stream: &mut OwnedWriteHalf) -> Result<(), Error> { +async fn reload(stream: &mut OwnedWriteHalf) -> Result<(), Error> { info!("Reloading config"); let config = get_config(); @@ -66,7 +207,7 @@ pub async fn reload(stream: &mut OwnedWriteHalf) -> Result<(), Error> { write_all_half(stream, res).await } -pub async fn show_config(stream: &mut OwnedWriteHalf) -> Result<(), Error> { +async fn show_config(stream: &mut OwnedWriteHalf) -> Result<(), Error> { let guard = get_config(); let config = &*guard.clone(); let config: HashMap = config.into(); @@ -153,7 +294,7 @@ pub async fn show_config(stream: &mut OwnedWriteHalf) -> Result<(), Error> { } /// SHOW STATS -pub async fn show_stats(stream: &mut OwnedWriteHalf) -> Result<(), Error> { +async fn show_stats(stream: &mut OwnedWriteHalf) -> Result<(), Error> { let columns = [ "database", "total_xact_count", diff --git a/src/client.rs b/src/client.rs index fb1a4b0..f73ef22 100644 --- a/src/client.rs +++ b/src/client.rs @@ -238,7 +238,7 @@ impl Client { // Handle admin database real quick if self.admin { trace!("Handling admin command"); - handle_admin(&mut self.write, message).await?; + handle_admin(&mut self.write, message, pool.clone()).await?; continue; } diff --git a/src/config.rs b/src/config.rs index 203f410..1687ff3 100644 --- a/src/config.rs +++ b/src/config.rs @@ -19,6 +19,15 @@ pub enum Role { Replica, } +impl ToString for Role { + fn to_string(&self) -> String { + match *self { + Role::Primary => "primary".to_string(), + Role::Replica => "replica".to_string(), + } + } +} + impl PartialEq> for Role { fn eq(&self, other: &Option) -> bool { match other { diff --git a/src/constants.rs b/src/constants.rs index 2560975..b829eaa 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -26,3 +26,5 @@ pub const MESSAGE_TERMINATOR: u8 = 0; // pub const OID_NUMERIC: i32 = 1700; pub const OID_TEXT: i32 = 25; +pub const OID_INT4: i32 = 23; // int +pub const _OID_INT8: i32 = 20; // bigint diff --git a/src/pool.rs b/src/pool.rs index 26d172b..431a909 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -336,6 +336,14 @@ impl ConnectionPool { pub fn servers(&self, shard: usize) -> usize { self.addresses[shard].len() } + + pub fn connections(&self, shard: usize, server: usize) -> bb8::State { + self.databases[shard][server].state() + } + + pub fn address(&self, shard: usize, server: usize) -> &Address { + &self.addresses[shard][server] + } } pub struct ServerPool {