Pass real server info to the client (#10)

This commit is contained in:
Lev Kokotov
2022-02-11 22:19:49 -08:00
committed by GitHub
parent ab8573c94f
commit 526b9eb666
5 changed files with 56 additions and 9 deletions

View File

@@ -63,6 +63,7 @@ impl Client {
client_server_map: ClientServerMap, client_server_map: ClientServerMap,
transaction_mode: bool, transaction_mode: bool,
default_server_role: Option<Role>, default_server_role: Option<Role>,
server_info: BytesMut,
) -> Result<Client, Error> { ) -> Result<Client, Error> {
loop { loop {
// Could be StartupMessage or SSLRequest // Could be StartupMessage or SSLRequest
@@ -102,7 +103,7 @@ impl Client {
let secret_key: i32 = rand::random(); let secret_key: i32 = rand::random();
auth_ok(&mut stream).await?; auth_ok(&mut stream).await?;
server_parameters(&mut stream).await?; write_all(&mut stream, server_info).await?;
backend_key_data(&mut stream, process_id, secret_key).await?; backend_key_data(&mut stream, process_id, secret_key).await?;
ready_for_query(&mut stream).await?; ready_for_query(&mut stream).await?;

View File

@@ -90,6 +90,11 @@ pub async fn parse(path: &str) -> Result<Config, Error> {
let mut dup_check = HashSet::new(); let mut dup_check = HashSet::new();
let mut primary_count = 0; let mut primary_count = 0;
if shard.1.servers.len() == 0 {
println!("> Shard {} has no servers configured", shard.0);
return Err(Error::BadConfig);
}
for server in &shard.1.servers { for server in &shard.1.servers {
dup_check.insert(server); dup_check.insert(server);

View File

@@ -86,7 +86,7 @@ async fn main() {
); );
println!("> Connection timeout: {}ms", config.general.connect_timeout); println!("> Connection timeout: {}ms", config.general.connect_timeout);
let pool = ConnectionPool::from_config(config.clone(), client_server_map.clone()).await; let mut pool = ConnectionPool::from_config(config.clone(), client_server_map.clone()).await;
let transaction_mode = config.general.pool_mode == "transaction"; let transaction_mode = config.general.pool_mode == "transaction";
let default_server_role = match config.query_router.default_role.as_ref() { let default_server_role = match config.query_router.default_role.as_ref() {
"any" => None, "any" => None,
@@ -98,11 +98,20 @@ async fn main() {
} }
}; };
let server_info = match pool.validate().await {
Ok(info) => info,
Err(err) => {
println!("> Could not validate connection pool: {:?}", err);
return;
}
};
println!("> Waiting for clients..."); println!("> Waiting for clients...");
loop { loop {
let pool = pool.clone(); let pool = pool.clone();
let client_server_map = client_server_map.clone(); let client_server_map = client_server_map.clone();
let server_info = server_info.clone();
let (socket, addr) = match listener.accept().await { let (socket, addr) = match listener.accept().await {
Ok((socket, addr)) => (socket, addr), Ok((socket, addr)) => (socket, addr),
@@ -124,6 +133,7 @@ async fn main() {
client_server_map, client_server_map,
transaction_mode, transaction_mode,
default_server_role, default_server_role,
server_info,
) )
.await .await
{ {

View File

@@ -8,10 +8,8 @@ use crate::errors::Error;
// This is a funny one. `psql` parses this to figure out which // This is a funny one. `psql` parses this to figure out which
// queries to send when using shortcuts, e.g. \d+. // queries to send when using shortcuts, e.g. \d+.
// // No longer used. Keeping it here until I'm sure we don't need it again.
// TODO: Actually get the version from the server itself. const _SERVER_VESION: &str = "12.9 (Ubuntu 12.9-0ubuntu0.20.04.1)";
//
const SERVER_VESION: &str = "12.9 (Ubuntu 12.9-0ubuntu0.20.04.1)";
/// Tell the client that authentication handshake completed successfully. /// Tell the client that authentication handshake completed successfully.
pub async fn auth_ok(stream: &mut TcpStream) -> Result<(), Error> { pub async fn auth_ok(stream: &mut TcpStream) -> Result<(), Error> {
@@ -27,12 +25,12 @@ pub async fn auth_ok(stream: &mut TcpStream) -> Result<(), Error> {
/// Send server parameters to the client. This will tell the client /// Send server parameters to the client. This will tell the client
/// what server version and what's the encoding we're using. /// what server version and what's the encoding we're using.
// //
// TODO: Forward these from the server instead of hardcoding. // No longer used. Keeping it here until I'm sure we don't need it again.
// //
pub async fn server_parameters(stream: &mut TcpStream) -> Result<(), Error> { pub async fn _server_parameters(stream: &mut TcpStream) -> Result<(), Error> {
let client_encoding = BytesMut::from(&b"client_encoding\0UTF8\0"[..]); let client_encoding = BytesMut::from(&b"client_encoding\0UTF8\0"[..]);
let server_version = let server_version =
BytesMut::from(&format!("server_version\0{}\0", SERVER_VESION).as_bytes()[..]); BytesMut::from(&format!("server_version\0{}\0", _SERVER_VESION).as_bytes()[..]);
// Client encoding // Client encoding
let len = client_encoding.len() as i32 + 4; // TODO: add more parameters here let len = client_encoding.len() as i32 + 4; // TODO: add more parameters here

View File

@@ -1,6 +1,7 @@
/// Pooling and failover and banlist. /// Pooling and failover and banlist.
use async_trait::async_trait; use async_trait::async_trait;
use bb8::{ManageConnection, Pool, PooledConnection}; use bb8::{ManageConnection, Pool, PooledConnection};
use bytes::BytesMut;
use chrono::naive::NaiveDateTime; use chrono::naive::NaiveDateTime;
use crate::config::{Address, Config, Role, User}; use crate::config::{Address, Config, Role, User};
@@ -105,6 +106,38 @@ impl ConnectionPool {
} }
} }
/// Connect to all shards and grab server information.
/// Return server information we will pass to the clients
/// when they connect.
pub async fn validate(&mut self) -> Result<BytesMut, Error> {
let mut server_infos = Vec::new();
for shard in 0..self.shards() {
// TODO: query all primary and replicas in the shard configuration.
let connection = match self.get(Some(shard), None).await {
Ok(conn) => conn,
Err(err) => {
println!("> Shard {} down or misconfigured.", shard);
return Err(err);
}
};
let mut proxy = connection.0;
let _address = connection.1;
let server = &mut *proxy;
server_infos.push(server.server_info());
}
// TODO: compare server information to make sure
// all shards are running identical configurations.
if server_infos.len() == 0 {
return Err(Error::AllServersDown);
}
Ok(server_infos[0].clone())
}
/// Get a connection from the pool. /// Get a connection from the pool.
pub async fn get( pub async fn get(
&mut self, &mut self,