mirror of
https://github.com/postgresml/pgcat.git
synced 2026-03-23 01:16:30 +00:00
Pass real server info to the client (#10)
This commit is contained in:
@@ -63,6 +63,7 @@ impl Client {
|
||||
client_server_map: ClientServerMap,
|
||||
transaction_mode: bool,
|
||||
default_server_role: Option<Role>,
|
||||
server_info: BytesMut,
|
||||
) -> Result<Client, Error> {
|
||||
loop {
|
||||
// Could be StartupMessage or SSLRequest
|
||||
@@ -102,7 +103,7 @@ impl Client {
|
||||
let secret_key: i32 = rand::random();
|
||||
|
||||
auth_ok(&mut stream).await?;
|
||||
server_parameters(&mut stream).await?;
|
||||
write_all(&mut stream, server_info).await?;
|
||||
backend_key_data(&mut stream, process_id, secret_key).await?;
|
||||
ready_for_query(&mut stream).await?;
|
||||
|
||||
|
||||
@@ -90,6 +90,11 @@ pub async fn parse(path: &str) -> Result<Config, Error> {
|
||||
let mut dup_check = HashSet::new();
|
||||
let mut primary_count = 0;
|
||||
|
||||
if shard.1.servers.len() == 0 {
|
||||
println!("> Shard {} has no servers configured", shard.0);
|
||||
return Err(Error::BadConfig);
|
||||
}
|
||||
|
||||
for server in &shard.1.servers {
|
||||
dup_check.insert(server);
|
||||
|
||||
|
||||
12
src/main.rs
12
src/main.rs
@@ -86,7 +86,7 @@ async fn main() {
|
||||
);
|
||||
println!("> Connection timeout: {}ms", config.general.connect_timeout);
|
||||
|
||||
let pool = ConnectionPool::from_config(config.clone(), client_server_map.clone()).await;
|
||||
let mut pool = ConnectionPool::from_config(config.clone(), client_server_map.clone()).await;
|
||||
let transaction_mode = config.general.pool_mode == "transaction";
|
||||
let default_server_role = match config.query_router.default_role.as_ref() {
|
||||
"any" => None,
|
||||
@@ -98,11 +98,20 @@ async fn main() {
|
||||
}
|
||||
};
|
||||
|
||||
let server_info = match pool.validate().await {
|
||||
Ok(info) => info,
|
||||
Err(err) => {
|
||||
println!("> Could not validate connection pool: {:?}", err);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
println!("> Waiting for clients...");
|
||||
|
||||
loop {
|
||||
let pool = pool.clone();
|
||||
let client_server_map = client_server_map.clone();
|
||||
let server_info = server_info.clone();
|
||||
|
||||
let (socket, addr) = match listener.accept().await {
|
||||
Ok((socket, addr)) => (socket, addr),
|
||||
@@ -124,6 +133,7 @@ async fn main() {
|
||||
client_server_map,
|
||||
transaction_mode,
|
||||
default_server_role,
|
||||
server_info,
|
||||
)
|
||||
.await
|
||||
{
|
||||
|
||||
@@ -8,10 +8,8 @@ use crate::errors::Error;
|
||||
|
||||
// This is a funny one. `psql` parses this to figure out which
|
||||
// queries to send when using shortcuts, e.g. \d+.
|
||||
//
|
||||
// TODO: Actually get the version from the server itself.
|
||||
//
|
||||
const SERVER_VESION: &str = "12.9 (Ubuntu 12.9-0ubuntu0.20.04.1)";
|
||||
// No longer used. Keeping it here until I'm sure we don't need it again.
|
||||
const _SERVER_VESION: &str = "12.9 (Ubuntu 12.9-0ubuntu0.20.04.1)";
|
||||
|
||||
/// Tell the client that authentication handshake completed successfully.
|
||||
pub async fn auth_ok(stream: &mut TcpStream) -> Result<(), Error> {
|
||||
@@ -27,12 +25,12 @@ pub async fn auth_ok(stream: &mut TcpStream) -> Result<(), Error> {
|
||||
/// Send server parameters to the client. This will tell the client
|
||||
/// what server version and what's the encoding we're using.
|
||||
//
|
||||
// TODO: Forward these from the server instead of hardcoding.
|
||||
// No longer used. Keeping it here until I'm sure we don't need it again.
|
||||
//
|
||||
pub async fn server_parameters(stream: &mut TcpStream) -> Result<(), Error> {
|
||||
pub async fn _server_parameters(stream: &mut TcpStream) -> Result<(), Error> {
|
||||
let client_encoding = BytesMut::from(&b"client_encoding\0UTF8\0"[..]);
|
||||
let server_version =
|
||||
BytesMut::from(&format!("server_version\0{}\0", SERVER_VESION).as_bytes()[..]);
|
||||
BytesMut::from(&format!("server_version\0{}\0", _SERVER_VESION).as_bytes()[..]);
|
||||
|
||||
// Client encoding
|
||||
let len = client_encoding.len() as i32 + 4; // TODO: add more parameters here
|
||||
|
||||
33
src/pool.rs
33
src/pool.rs
@@ -1,6 +1,7 @@
|
||||
/// Pooling and failover and banlist.
|
||||
use async_trait::async_trait;
|
||||
use bb8::{ManageConnection, Pool, PooledConnection};
|
||||
use bytes::BytesMut;
|
||||
use chrono::naive::NaiveDateTime;
|
||||
|
||||
use crate::config::{Address, Config, Role, User};
|
||||
@@ -105,6 +106,38 @@ impl ConnectionPool {
|
||||
}
|
||||
}
|
||||
|
||||
/// Connect to all shards and grab server information.
|
||||
/// Return server information we will pass to the clients
|
||||
/// when they connect.
|
||||
pub async fn validate(&mut self) -> Result<BytesMut, Error> {
|
||||
let mut server_infos = Vec::new();
|
||||
|
||||
for shard in 0..self.shards() {
|
||||
// TODO: query all primary and replicas in the shard configuration.
|
||||
let connection = match self.get(Some(shard), None).await {
|
||||
Ok(conn) => conn,
|
||||
Err(err) => {
|
||||
println!("> Shard {} down or misconfigured.", shard);
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
|
||||
let mut proxy = connection.0;
|
||||
let _address = connection.1;
|
||||
let server = &mut *proxy;
|
||||
|
||||
server_infos.push(server.server_info());
|
||||
}
|
||||
|
||||
// TODO: compare server information to make sure
|
||||
// all shards are running identical configurations.
|
||||
if server_infos.len() == 0 {
|
||||
return Err(Error::AllServersDown);
|
||||
}
|
||||
|
||||
Ok(server_infos[0].clone())
|
||||
}
|
||||
|
||||
/// Get a connection from the pool.
|
||||
pub async fn get(
|
||||
&mut self,
|
||||
|
||||
Reference in New Issue
Block a user