some comments

This commit is contained in:
Lev Kokotov
2022-02-09 06:51:31 -08:00
parent 9fe50c48e8
commit 4c16ba3848
4 changed files with 20 additions and 9 deletions

View File

@@ -168,7 +168,7 @@ impl Client {
// Active shard we're talking to. // Active shard we're talking to.
// The lifetime of this depends on the pool mode: // The lifetime of this depends on the pool mode:
// - if in session mode, this lives until client disconnects or changes it, // - if in session mode, this lives until the client disconnects,
// - if in transaction mode, this lives for the duration of one transaction. // - if in transaction mode, this lives for the duration of one transaction.
let mut shard: Option<usize> = None; let mut shard: Option<usize> = None;
@@ -177,7 +177,7 @@ impl Client {
// either a `Q` (query) or `P` (prepare, extended protocol). // either a `Q` (query) or `P` (prepare, extended protocol).
// We can parse it here before grabbing a server from the pool, // We can parse it here before grabbing a server from the pool,
// in case the client is sending some control messages, e.g. // in case the client is sending some control messages, e.g.
// SET sharding_context.key = '1234'; // SET SHARDING KEY TO 'bigint';
let mut message = read_message(&mut self.read).await?; let mut message = read_message(&mut self.read).await?;
// Parse for special select shard command. // Parse for special select shard command.
@@ -191,9 +191,6 @@ impl Client {
None => (), None => (),
}; };
// The message is part of the regular protocol.
// self.buffer.put(message);
// Grab a server from the pool. // Grab a server from the pool.
// None = any shard // None = any shard
let connection = pool.get(shard).await.unwrap(); let connection = pool.get(shard).await.unwrap();
@@ -361,12 +358,19 @@ impl Client {
guard.remove(&(self.process_id, self.secret_key)); guard.remove(&(self.process_id, self.secret_key));
} }
/// Determine if the query is part of our special syntax, extract
/// the shard key, and return the shard to query based on Postgres'
/// PARTITION BY HASH function.
async fn select_shard(&mut self, mut buf: BytesMut, shards: usize) -> Option<usize> { async fn select_shard(&mut self, mut buf: BytesMut, shards: usize) -> Option<usize> {
let code = buf.get_u8() as char; let code = buf.get_u8() as char;
// Only supporting simpe protocol here, so
// one would have to execute something like this:
// psql -c "SET SHARDING KEY TO '1234'"
// after sanitizing the value manually, which can be just done with an
// int parser, e.g. `let key = "1234".parse::<i64>().unwrap()`.
match code { match code {
'Q' => (), 'Q' => (),
// 'P' => (),
_ => return None, _ => return None,
}; };

View File

@@ -43,8 +43,8 @@ pub struct Config {
pub shards: HashMap<String, Shard>, pub shards: HashMap<String, Shard>,
} }
/// Parse the config.
pub async fn parse(path: &str) -> Result<Config, Error> { pub async fn parse(path: &str) -> Result<Config, Error> {
// let path = Path::new(path);
let mut contents = String::new(); let mut contents = String::new();
let mut file = match File::open(path).await { let mut file = match File::open(path).await {
Ok(file) => file, Ok(file) => file,

View File

@@ -26,6 +26,9 @@ pub async fn auth_ok(stream: &mut TcpStream) -> Result<(), Error> {
/// Send server parameters to the client. This will tell the client /// Send server parameters to the client. This will tell the client
/// what server version and what's the encoding we're using. /// what server version and what's the encoding we're using.
//
// TODO: Forward these from the server instead of hardcoding.
//
pub async fn server_parameters(stream: &mut TcpStream) -> Result<(), Error> { pub async fn server_parameters(stream: &mut TcpStream) -> Result<(), Error> {
let client_encoding = BytesMut::from(&b"client_encoding\0UTF8\0"[..]); let client_encoding = BytesMut::from(&b"client_encoding\0UTF8\0"[..]);
let server_version = let server_version =
@@ -138,16 +141,20 @@ pub async fn md5_password(
Ok(write_all(stream, message).await?) Ok(write_all(stream, message).await?)
} }
/// Implements a response to our custom `SET SHARDING KEY` command.
/// This tells the client we're ready for the next query.
pub async fn set_sharding_key(stream: &mut OwnedWriteHalf) -> Result<(), Error> { pub async fn set_sharding_key(stream: &mut OwnedWriteHalf) -> Result<(), Error> {
let mut res = BytesMut::with_capacity(25); let mut res = BytesMut::with_capacity(25);
let set_complete = BytesMut::from(&"SET SHARDING KEY\0"[..]); let set_complete = BytesMut::from(&"SET SHARDING KEY\0"[..]);
let len = (set_complete.len() + 4) as i32; let len = (set_complete.len() + 4) as i32;
// CommandComplete
res.put_u8(b'C'); res.put_u8(b'C');
res.put_i32(len); res.put_i32(len);
res.put_slice(&set_complete[..]); res.put_slice(&set_complete[..]);
// ReadyForQuery (idle)
res.put_u8(b'Z'); res.put_u8(b'Z');
res.put_i32(5); res.put_i32(5);
res.put_u8(b'I'); res.put_u8(b'I');

View File

@@ -182,7 +182,7 @@ impl ConnectionPool {
pub fn is_banned(&self, address: &Address, shard: usize) -> bool { pub fn is_banned(&self, address: &Address, shard: usize) -> bool {
let mut guard = self.banlist.lock().unwrap(); let mut guard = self.banlist.lock().unwrap();
// Everything is banned, nothig is banned // Everything is banned = nothing is banned.
if guard[shard].len() == self.databases[shard].len() { if guard[shard].len() == self.databases[shard].len() {
guard[shard].clear(); guard[shard].clear();
drop(guard); drop(guard);
@@ -194,8 +194,8 @@ impl ConnectionPool {
match guard[shard].get(address) { match guard[shard].get(address) {
Some(timestamp) => { Some(timestamp) => {
let now = chrono::offset::Utc::now().naive_utc(); let now = chrono::offset::Utc::now().naive_utc();
// Ban expired.
if now.timestamp() - timestamp.timestamp() > self.ban_time { if now.timestamp() - timestamp.timestamp() > self.ban_time {
// 1 minute
guard[shard].remove(address); guard[shard].remove(address);
false false
} else { } else {