diff --git a/src/admin.rs b/src/admin.rs index fb4a718..0d2e197 100644 --- a/src/admin.rs +++ b/src/admin.rs @@ -5,9 +5,8 @@ use tokio::net::tcp::OwnedWriteHalf; use std::collections::HashMap; use crate::config::{get_config, parse, Role}; -use crate::constants::{OID_INT4, OID_NUMERIC, OID_TEXT}; use crate::errors::Error; -use crate::messages::{custom_protocol_response_ok, error_response, write_all_half}; +use crate::messages::*; use crate::pool::ConnectionPool; use crate::stats::get_stats; @@ -56,115 +55,66 @@ async fn show_databases(stream: &mut OwnedWriteHalf, pool: &ConnectionPool) -> R let config = &*guard.clone(); drop(guard); - let columns = [ - "name", - "host", - "port", - "database", - "force_user", - "pool_size", - "min_pool_size", - "reserve_pool", - "pool_mode", - "max_connections", - "current_connections", - "paused", - "disabled", - ]; - - let types = [ - OID_TEXT, OID_TEXT, OID_TEXT, OID_TEXT, OID_TEXT, OID_INT4, OID_INT4, OID_INT4, OID_TEXT, - OID_INT4, OID_INT4, OID_INT4, OID_INT4, + // Columns + let columns = vec![ + ("name", DataType::Text), + ("host", DataType::Text), + ("port", DataType::Text), + ("database", DataType::Text), + ("force_user", DataType::Text), + ("pool_size", DataType::Int4), + ("min_pool_size", DataType::Int4), + ("reserve_pool", DataType::Int4), + ("pool_mode", DataType::Text), + ("max_connections", DataType::Int4), + ("current_connections", DataType::Int4), + ("paused", DataType::Int4), + ("disabled", DataType::Int4), ]; let mut res = BytesMut::new(); - let mut row_desc = BytesMut::new(); - row_desc.put_i16(columns.len() as i16); - for (i, column) in columns.iter().enumerate() { - row_desc.put_slice(&format!("{}\0", column).as_bytes()); - - // Doesn't belong to any table - row_desc.put_i32(0); - - // Doesn't belong to any table - row_desc.put_i16(0); - - // Data type - row_desc.put_i32(types[i]); - - // text size = variable (-1) - row_desc.put_i16(if types[i] == OID_TEXT { -1 } else { 4 }); - - // Type modifier: none that I know - row_desc.put_i32(-1); - - // Format being used: text (0), binary (1) - row_desc.put_i16(0); - } - - res.put_u8(b'T'); - res.put_i32(row_desc.len() as i32 + 4); - res.put(row_desc); + // RowDescription + res.put(row_description(&columns)); for shard in 0..pool.shards() { let database_name = &config.shards[&shard.to_string()].database; let mut replica_count = 0; + for server in 0..pool.servers(shard) { - // DataRow - let mut data_row = BytesMut::new(); - data_row.put_i16(columns.len() as i16); - let address = pool.address(shard, server); - let role = address.role.to_string(); - let name = match role.as_ref() { - "primary" => format!("shard_{}_primary", shard), - "replica" => format!("shard_{}_replica_{}", shard, replica_count), - _ => unreachable!(), + let name = match address.role { + Role::Primary => format!("shard_{}_primary", shard), + + Role::Replica => { + let name = format!("shard_{}_replica_{}", shard, replica_count); + replica_count += 1; + name + } }; - let connections = pool.connections(shard, server); + let pool_state = pool.pool_state(shard, server); - let data = HashMap::from([ - ("host", address.host.to_string()), - ("port", address.port.to_string()), - ("role", role), - ("name", name), - ("database", database_name.to_string()), - ("force_user", config.user.name.to_string()), - ("pool_size", config.general.pool_size.to_string()), - ("min_pool_size", "0".to_string()), - ("reserve_pool", "0".to_string()), - ("pool_mode", config.general.pool_mode.to_string()), - // There is only one user support at the moment, - // so max_connections = num of users * pool_size = 1 * pool_size. - ("max_connections", config.general.pool_size.to_string()), - ("current_connections", connections.connections.to_string()), - ("paused", "0".to_string()), - ("disabled", "0".to_string()), - ]); - - for column in &columns { - let value = data[column].as_bytes(); - - data_row.put_i32(value.len() as i32); - data_row.put_slice(&value); - } - - res.put_u8(b'D'); - res.put_i32(data_row.len() as i32 + 4); - res.put(data_row); - - if address.role == Role::Replica { - replica_count += 1; - } + res.put(data_row(&vec![ + name, // name + address.host.to_string(), // host + address.port.to_string(), // port + database_name.to_string(), // database + config.user.name.to_string(), // force_user + config.general.pool_size.to_string(), // pool_size + "0".to_string(), // min_pool_size + "0".to_string(), // reserve_pool + config.general.pool_mode.to_string(), // pool_mode + config.general.pool_size.to_string(), // max_connections + pool_state.connections.to_string(), // current_connections + "0".to_string(), // paused + "0".to_string(), // disabled + ])); } } - let command_complete = BytesMut::from(&"SHOW\0"[..]); - res.put_u8(b'C'); - res.put_i32(command_complete.len() as i32 + 4); - res.put(command_complete); + res.put(command_complete("SHOW")); + // ReadyForQuery res.put_u8(b'Z'); res.put_i32(5); res.put_u8(b'I'); @@ -194,10 +144,7 @@ async fn reload(stream: &mut OwnedWriteHalf) -> Result<(), Error> { let mut res = BytesMut::new(); // CommandComplete - let command_complete = BytesMut::from(&"RELOAD\0"[..]); - res.put_u8(b'C'); - res.put_i32(command_complete.len() as i32 + 4); - res.put(command_complete); + res.put(command_complete("RELOAD")); // ReadyForQuery res.put_u8(b'Z'); @@ -217,74 +164,31 @@ async fn show_config(stream: &mut OwnedWriteHalf) -> Result<(), Error> { let immutables = ["host", "port", "connect_timeout"]; // Columns - let columns = ["key", "value", "default", "changeable"]; - - // RowDescription - let mut row_desc = BytesMut::new(); - row_desc.put_i16(4 as i16); // key, value, default, changeable - - for column in columns { - row_desc.put_slice(&format!("{}\0", column).as_bytes()); - - // Doesn't belong to any table - row_desc.put_i32(0); - - // Doesn't belong to any table - row_desc.put_i16(0); - - // Data type - row_desc.put_i32(OID_TEXT); - - // text size = variable (-1) - row_desc.put_i16(-1); - - // Type modifier: none that I know - row_desc.put_i32(-1); - - // Format being used: text (0), binary (1) - row_desc.put_i16(0); - } + let columns = vec![ + ("key", DataType::Text), + ("value", DataType::Text), + ("default", DataType::Text), + ("changeable", DataType::Text), + ]; // Response data let mut res = BytesMut::new(); - res.put_u8(b'T'); - res.put_i32(row_desc.len() as i32 + 4); - res.put(row_desc); + res.put(row_description(&columns)); // DataRow rows for (key, value) in config { - let mut data_row = BytesMut::new(); - - data_row.put_i16(4 as i16); // key, value, default, changeable - - let key_bytes = key.as_bytes(); - let value = value.as_bytes(); - - data_row.put_i32(key_bytes.len() as i32); - data_row.put_slice(&key_bytes); - - data_row.put_i32(value.len() as i32); - data_row.put_slice(&value); - - data_row.put_i32(1 as i32); - data_row.put_slice(&"-".as_bytes()); - let changeable = if immutables.iter().filter(|col| *col == &key).count() == 1 { - "no".as_bytes() + "no".to_string() } else { - "yes".as_bytes() + "yes".to_string() }; - data_row.put_i32(changeable.len() as i32); - data_row.put_slice(&changeable); - res.put_u8(b'D'); - res.put_i32(data_row.len() as i32 + 4); - res.put(data_row); + let row = vec![key, value, "-".to_string(), changeable]; + + res.put(data_row(&row)); } - res.put_u8(b'C'); - res.put_i32("SHOW CONFIG\0".as_bytes().len() as i32 + 4); - res.put_slice(&"SHOW CONFIG\0".as_bytes()); + res.put(command_complete("SHOW")); res.put_u8(b'Z'); res.put_i32(5); @@ -295,81 +199,38 @@ async fn show_config(stream: &mut OwnedWriteHalf) -> Result<(), Error> { /// SHOW STATS async fn show_stats(stream: &mut OwnedWriteHalf) -> Result<(), Error> { - let columns = [ - "database", - "total_xact_count", - "total_query_count", - "total_received", - "total_sent", - "total_xact_time", - "total_query_time", - "total_wait_time", - "avg_xact_count", - "avg_query_count", - "avg_recv", - "avg_sent", - "avg_xact_time", - "avg_query_time", - "avg_wait_time", + let columns = vec![ + ("database", DataType::Text), + ("total_xact_count", DataType::Numeric), + ("total_query_count", DataType::Numeric), + ("total_received", DataType::Numeric), + ("total_sent", DataType::Numeric), + ("total_xact_time", DataType::Numeric), + ("total_query_time", DataType::Numeric), + ("total_wait_time", DataType::Numeric), + ("avg_xact_count", DataType::Numeric), + ("avg_query_count", DataType::Numeric), + ("avg_recv", DataType::Numeric), + ("avg_sent", DataType::Numeric), + ("avg_xact_time", DataType::Numeric), + ("avg_query_time", DataType::Numeric), + ("avg_wait_time", DataType::Numeric), ]; let stats = get_stats(); let mut res = BytesMut::new(); - let mut row_desc = BytesMut::new(); - let mut data_row = BytesMut::new(); + res.put(row_description(&columns)); - // Number of columns: 1 - row_desc.put_i16(columns.len() as i16); - data_row.put_i16(columns.len() as i16); + let mut row = vec![ + String::from("all shards"), // TODO: per-database stats, + ]; - for (i, column) in columns.iter().enumerate() { - // RowDescription - - // Column name - row_desc.put_slice(&format!("{}\0", column).as_bytes()); - - // Doesn't belong to any table - row_desc.put_i32(0); - - // Doesn't belong to any table - row_desc.put_i16(0); - - // Data type - row_desc.put_i32(if i == 0 { OID_TEXT } else { OID_NUMERIC }); - - // Numeric/text size = variable (-1) - row_desc.put_i16(-1); - - // Type modifier: none that I know - row_desc.put_i32(-1); - - // Format being used: text (0), binary (1) - row_desc.put_i16(0); - - // DataRow - let value = if i == 0 { - String::from("all shards") - } else { - stats.get(&column.to_string()).unwrap_or(&0).to_string() - }; - - data_row.put_i32(value.len() as i32); - data_row.put_slice(value.as_bytes()); + for column in &columns[1..] { + row.push(stats.get(column.0).unwrap_or(&0).to_string()); } - let command_complete = BytesMut::from(&"SHOW\0"[..]); - - res.put_u8(b'T'); - res.put_i32(row_desc.len() as i32 + 4); - res.put(row_desc); - - res.put_u8(b'D'); - res.put_i32(data_row.len() as i32 + 4); - res.put(data_row); - - res.put_u8(b'C'); - res.put_i32(command_complete.len() as i32 + 4); - res.put(command_complete); + res.put(data_row(&row)); + res.put(command_complete("SHOW")); res.put_u8(b'Z'); res.put_i32(5); diff --git a/src/constants.rs b/src/constants.rs index b829eaa..074811e 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -24,7 +24,4 @@ pub const MESSAGE_TERMINATOR: u8 = 0; // // Data types // -pub const OID_NUMERIC: i32 = 1700; -pub const OID_TEXT: i32 = 25; -pub const OID_INT4: i32 = 23; // int pub const _OID_INT8: i32 = 20; // bigint diff --git a/src/messages.rs b/src/messages.rs index 8f7080f..473c8de 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -8,9 +8,26 @@ use tokio::net::{ TcpStream, }; +use crate::errors::Error; use std::collections::HashMap; -use crate::errors::Error; +/// Postgres data type mappings +/// used in RowDescription ('T') message. +pub enum DataType { + Text, + Int4, + Numeric, +} + +impl From<&DataType> for i32 { + fn from(data_type: &DataType) -> i32 { + match data_type { + DataType::Text => 25, + DataType::Int4 => 23, + DataType::Numeric => 1700, + } + } +} /// Tell the client that authentication handshake completed successfully. pub async fn auth_ok(stream: &mut TcpStream) -> Result<(), Error> { @@ -259,68 +276,17 @@ pub async fn show_response( // 3. CommandComplete // 4. ReadyForQuery - // RowDescription - let mut row_desc = BytesMut::new(); - - // Number of columns: 1 - row_desc.put_i16(1); - - // Column name - row_desc.put_slice(&format!("{}\0", name).as_bytes()); - - // Doesn't belong to any table - row_desc.put_i32(0); - - // Doesn't belong to any table - row_desc.put_i16(0); - - // Text - row_desc.put_i32(25); - - // Text size = variable (-1) - row_desc.put_i16(-1); - - // Type modifier: none that I know - row_desc.put_i32(-1); - - // Format being used: text (0), binary (1) - row_desc.put_i16(0); - - // DataRow - let mut data_row = BytesMut::new(); - - // Number of columns - data_row.put_i16(1); - - // Size of the column content (length of the string really) - data_row.put_i32(value.len() as i32); - - // The content - data_row.put_slice(value.as_bytes()); - - // CommandComplete - let mut command_complete = BytesMut::new(); - - // Number of rows returned (just one) - command_complete.put_slice(&b"SELECT 1\0"[..]); - // The final messages sent to the client let mut res = BytesMut::new(); // RowDescription - res.put_u8(b'T'); - res.put_i32(row_desc.len() as i32 + 4); - res.put(row_desc); + res.put(row_description(&vec![(name, DataType::Text)])); // DataRow - res.put_u8(b'D'); - res.put_i32(data_row.len() as i32 + 4); - res.put(data_row); + res.put(data_row(&vec![value.to_string()])); // CommandComplete - res.put_u8(b'C'); - res.put_i32(command_complete.len() as i32 + 4); - res.put(command_complete); + res.put(command_complete("SELECT 1")); // ReadyForQuery res.put_u8(b'Z'); @@ -330,6 +296,77 @@ pub async fn show_response( write_all_half(stream, res).await } +pub fn row_description(columns: &Vec<(&str, DataType)>) -> BytesMut { + let mut res = BytesMut::new(); + let mut row_desc = BytesMut::new(); + + // how many colums we are storing + row_desc.put_i16(columns.len() as i16); + + for (name, data_type) in columns { + // Column name + row_desc.put_slice(&format!("{}\0", name).as_bytes()); + + // Doesn't belong to any table + row_desc.put_i32(0); + + // Doesn't belong to any table + row_desc.put_i16(0); + + // Text + row_desc.put_i32(data_type.into()); + + // Text size = variable (-1) + let type_size = match data_type { + DataType::Text => -1, + DataType::Int4 => 4, + DataType::Numeric => -1, + }; + + row_desc.put_i16(type_size); + + // Type modifier: none that I know + row_desc.put_i32(-1); + + // Format being used: text (0), binary (1) + row_desc.put_i16(0); + } + + res.put_u8(b'T'); + res.put_i32(row_desc.len() as i32 + 4); + res.put(row_desc); + + res +} + +pub fn data_row(row: &Vec) -> BytesMut { + let mut res = BytesMut::new(); + let mut data_row = BytesMut::new(); + + data_row.put_i16(row.len() as i16); + + for column in row { + let column = column.as_bytes(); + data_row.put_i32(column.len() as i32); + data_row.put_slice(&column); + } + + res.put_u8(b'D'); + res.put_i32(data_row.len() as i32 + 4); + res.put(data_row); + + res +} + +pub fn command_complete(command: &str) -> BytesMut { + let cmd = BytesMut::from(format!("{}\0", command).as_bytes()); + let mut res = BytesMut::new(); + res.put_u8(b'C'); + res.put_i32(cmd.len() as i32 + 4); + res.put(cmd); + res +} + /// Write all data in the buffer to the TcpStream. pub async fn write_all(stream: &mut TcpStream, buf: BytesMut) -> Result<(), Error> { match stream.write_all(&buf).await { diff --git a/src/pool.rs b/src/pool.rs index 431a909..92b7663 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -337,7 +337,7 @@ impl ConnectionPool { self.addresses[shard].len() } - pub fn connections(&self, shard: usize, server: usize) -> bb8::State { + pub fn pool_state(&self, shard: usize, server: usize) -> bb8::State { self.databases[shard][server].state() }