Pass real server info to the client (#10)

This commit is contained in:
Lev Kokotov
2022-02-11 22:19:49 -08:00
committed by GitHub
parent ab8573c94f
commit 526b9eb666
5 changed files with 56 additions and 9 deletions

View File

@@ -63,6 +63,7 @@ impl Client {
client_server_map: ClientServerMap,
transaction_mode: bool,
default_server_role: Option<Role>,
server_info: BytesMut,
) -> Result<Client, Error> {
loop {
// Could be StartupMessage or SSLRequest
@@ -102,7 +103,7 @@ impl Client {
let secret_key: i32 = rand::random();
auth_ok(&mut stream).await?;
server_parameters(&mut stream).await?;
write_all(&mut stream, server_info).await?;
backend_key_data(&mut stream, process_id, secret_key).await?;
ready_for_query(&mut stream).await?;

View File

@@ -90,6 +90,11 @@ pub async fn parse(path: &str) -> Result<Config, Error> {
let mut dup_check = HashSet::new();
let mut primary_count = 0;
if shard.1.servers.len() == 0 {
println!("> Shard {} has no servers configured", shard.0);
return Err(Error::BadConfig);
}
for server in &shard.1.servers {
dup_check.insert(server);

View File

@@ -86,7 +86,7 @@ async fn main() {
);
println!("> Connection timeout: {}ms", config.general.connect_timeout);
let pool = ConnectionPool::from_config(config.clone(), client_server_map.clone()).await;
let mut pool = ConnectionPool::from_config(config.clone(), client_server_map.clone()).await;
let transaction_mode = config.general.pool_mode == "transaction";
let default_server_role = match config.query_router.default_role.as_ref() {
"any" => None,
@@ -98,11 +98,20 @@ async fn main() {
}
};
let server_info = match pool.validate().await {
Ok(info) => info,
Err(err) => {
println!("> Could not validate connection pool: {:?}", err);
return;
}
};
println!("> Waiting for clients...");
loop {
let pool = pool.clone();
let client_server_map = client_server_map.clone();
let server_info = server_info.clone();
let (socket, addr) = match listener.accept().await {
Ok((socket, addr)) => (socket, addr),
@@ -124,6 +133,7 @@ async fn main() {
client_server_map,
transaction_mode,
default_server_role,
server_info,
)
.await
{

View File

@@ -8,10 +8,8 @@ use crate::errors::Error;
// This is a funny one. `psql` parses this to figure out which
// queries to send when using shortcuts, e.g. \d+.
//
// TODO: Actually get the version from the server itself.
//
const SERVER_VESION: &str = "12.9 (Ubuntu 12.9-0ubuntu0.20.04.1)";
// No longer used. Keeping it here until I'm sure we don't need it again.
const _SERVER_VESION: &str = "12.9 (Ubuntu 12.9-0ubuntu0.20.04.1)";
/// Tell the client that authentication handshake completed successfully.
pub async fn auth_ok(stream: &mut TcpStream) -> Result<(), Error> {
@@ -27,12 +25,12 @@ pub async fn auth_ok(stream: &mut TcpStream) -> Result<(), Error> {
/// Send server parameters to the client. This will tell the client
/// what server version and what's the encoding we're using.
//
// TODO: Forward these from the server instead of hardcoding.
// No longer used. Keeping it here until I'm sure we don't need it again.
//
pub async fn server_parameters(stream: &mut TcpStream) -> Result<(), Error> {
pub async fn _server_parameters(stream: &mut TcpStream) -> Result<(), Error> {
let client_encoding = BytesMut::from(&b"client_encoding\0UTF8\0"[..]);
let server_version =
BytesMut::from(&format!("server_version\0{}\0", SERVER_VESION).as_bytes()[..]);
BytesMut::from(&format!("server_version\0{}\0", _SERVER_VESION).as_bytes()[..]);
// Client encoding
let len = client_encoding.len() as i32 + 4; // TODO: add more parameters here

View File

@@ -1,6 +1,7 @@
/// Pooling and failover and banlist.
use async_trait::async_trait;
use bb8::{ManageConnection, Pool, PooledConnection};
use bytes::BytesMut;
use chrono::naive::NaiveDateTime;
use crate::config::{Address, Config, Role, User};
@@ -105,6 +106,38 @@ impl ConnectionPool {
}
}
/// Connect to all shards and grab server information.
/// Return server information we will pass to the clients
/// when they connect.
pub async fn validate(&mut self) -> Result<BytesMut, Error> {
let mut server_infos = Vec::new();
for shard in 0..self.shards() {
// TODO: query all primary and replicas in the shard configuration.
let connection = match self.get(Some(shard), None).await {
Ok(conn) => conn,
Err(err) => {
println!("> Shard {} down or misconfigured.", shard);
return Err(err);
}
};
let mut proxy = connection.0;
let _address = connection.1;
let server = &mut *proxy;
server_infos.push(server.server_info());
}
// TODO: compare server information to make sure
// all shards are running identical configurations.
if server_infos.len() == 0 {
return Err(Error::AllServersDown);
}
Ok(server_infos[0].clone())
}
/// Get a connection from the pool.
pub async fn get(
&mut self,