diff --git a/src/client.rs b/src/client.rs index a9e7a60..154716c 100644 --- a/src/client.rs +++ b/src/client.rs @@ -63,6 +63,7 @@ impl Client { client_server_map: ClientServerMap, transaction_mode: bool, default_server_role: Option, + server_info: BytesMut, ) -> Result { loop { // Could be StartupMessage or SSLRequest @@ -102,7 +103,7 @@ impl Client { let secret_key: i32 = rand::random(); auth_ok(&mut stream).await?; - server_parameters(&mut stream).await?; + write_all(&mut stream, server_info).await?; backend_key_data(&mut stream, process_id, secret_key).await?; ready_for_query(&mut stream).await?; diff --git a/src/config.rs b/src/config.rs index 79c2c0b..d2e050b 100644 --- a/src/config.rs +++ b/src/config.rs @@ -90,6 +90,11 @@ pub async fn parse(path: &str) -> Result { let mut dup_check = HashSet::new(); let mut primary_count = 0; + if shard.1.servers.len() == 0 { + println!("> Shard {} has no servers configured", shard.0); + return Err(Error::BadConfig); + } + for server in &shard.1.servers { dup_check.insert(server); diff --git a/src/main.rs b/src/main.rs index 05903d6..37f820f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -86,7 +86,7 @@ async fn main() { ); println!("> Connection timeout: {}ms", config.general.connect_timeout); - let pool = ConnectionPool::from_config(config.clone(), client_server_map.clone()).await; + let mut pool = ConnectionPool::from_config(config.clone(), client_server_map.clone()).await; let transaction_mode = config.general.pool_mode == "transaction"; let default_server_role = match config.query_router.default_role.as_ref() { "any" => None, @@ -98,11 +98,20 @@ async fn main() { } }; + let server_info = match pool.validate().await { + Ok(info) => info, + Err(err) => { + println!("> Could not validate connection pool: {:?}", err); + return; + } + }; + println!("> Waiting for clients..."); loop { let pool = pool.clone(); let client_server_map = client_server_map.clone(); + let server_info = server_info.clone(); let (socket, addr) = match listener.accept().await { Ok((socket, addr)) => (socket, addr), @@ -124,6 +133,7 @@ async fn main() { client_server_map, transaction_mode, default_server_role, + server_info, ) .await { diff --git a/src/messages.rs b/src/messages.rs index 5f17d8d..cd99de7 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -8,10 +8,8 @@ use crate::errors::Error; // This is a funny one. `psql` parses this to figure out which // queries to send when using shortcuts, e.g. \d+. -// -// TODO: Actually get the version from the server itself. -// -const SERVER_VESION: &str = "12.9 (Ubuntu 12.9-0ubuntu0.20.04.1)"; +// No longer used. Keeping it here until I'm sure we don't need it again. +const _SERVER_VESION: &str = "12.9 (Ubuntu 12.9-0ubuntu0.20.04.1)"; /// Tell the client that authentication handshake completed successfully. pub async fn auth_ok(stream: &mut TcpStream) -> Result<(), Error> { @@ -27,12 +25,12 @@ pub async fn auth_ok(stream: &mut TcpStream) -> Result<(), Error> { /// Send server parameters to the client. This will tell the client /// what server version and what's the encoding we're using. // -// TODO: Forward these from the server instead of hardcoding. +// No longer used. Keeping it here until I'm sure we don't need it again. // -pub async fn server_parameters(stream: &mut TcpStream) -> Result<(), Error> { +pub async fn _server_parameters(stream: &mut TcpStream) -> Result<(), Error> { let client_encoding = BytesMut::from(&b"client_encoding\0UTF8\0"[..]); let server_version = - BytesMut::from(&format!("server_version\0{}\0", SERVER_VESION).as_bytes()[..]); + BytesMut::from(&format!("server_version\0{}\0", _SERVER_VESION).as_bytes()[..]); // Client encoding let len = client_encoding.len() as i32 + 4; // TODO: add more parameters here diff --git a/src/pool.rs b/src/pool.rs index 8c92bc6..3432330 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -1,6 +1,7 @@ /// Pooling and failover and banlist. use async_trait::async_trait; use bb8::{ManageConnection, Pool, PooledConnection}; +use bytes::BytesMut; use chrono::naive::NaiveDateTime; use crate::config::{Address, Config, Role, User}; @@ -105,6 +106,38 @@ impl ConnectionPool { } } + /// Connect to all shards and grab server information. + /// Return server information we will pass to the clients + /// when they connect. + pub async fn validate(&mut self) -> Result { + let mut server_infos = Vec::new(); + + for shard in 0..self.shards() { + // TODO: query all primary and replicas in the shard configuration. + let connection = match self.get(Some(shard), None).await { + Ok(conn) => conn, + Err(err) => { + println!("> Shard {} down or misconfigured.", shard); + return Err(err); + } + }; + + let mut proxy = connection.0; + let _address = connection.1; + let server = &mut *proxy; + + server_infos.push(server.server_info()); + } + + // TODO: compare server information to make sure + // all shards are running identical configurations. + if server_infos.len() == 0 { + return Err(Error::AllServersDown); + } + + Ok(server_infos[0].clone()) + } + /// Get a connection from the pool. pub async fn get( &mut self,