Query parser 3.0 (#23)

* Starting query parsing

* Query parser

* working config

* disable by default

* fix tsets

* introducing log crate; test for query router; comments

* typo

* fixes for banning

* added test for prepared stmt
This commit is contained in:
Lev Kokotov
2022-02-18 07:10:18 -08:00
committed by GitHub
parent 4c8a3987fe
commit aa796289bf
8 changed files with 296 additions and 75 deletions

11
Cargo.lock generated
View File

@@ -322,6 +322,7 @@ dependencies = [
"bb8", "bb8",
"bytes", "bytes",
"chrono", "chrono",
"log",
"md-5", "md-5",
"num_cpus", "num_cpus",
"once_cell", "once_cell",
@@ -330,6 +331,7 @@ dependencies = [
"serde", "serde",
"serde_derive", "serde_derive",
"sha-1", "sha-1",
"sqlparser",
"statsd", "statsd",
"tokio", "tokio",
"toml", "toml",
@@ -492,6 +494,15 @@ version = "1.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f2dd574626839106c320a323308629dcb1acfc96e32a8cba364ddc61ac23ee83" checksum = "f2dd574626839106c320a323308629dcb1acfc96e32a8cba364ddc61ac23ee83"
[[package]]
name = "sqlparser"
version = "0.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b8f192f29f4aa49e57bebd0aa05858e0a1f32dd270af36efe49edb82cbfffab6"
dependencies = [
"log",
]
[[package]] [[package]]
name = "statsd" name = "statsd"
version = "0.15.0" version = "0.15.0"

View File

@@ -21,3 +21,5 @@ regex = "1"
num_cpus = "1" num_cpus = "1"
once_cell = "1" once_cell = "1"
statsd = "0.15" statsd = "0.15"
sqlparser = "0.14"
log = "0.4"

View File

@@ -81,3 +81,15 @@ database = "shard2"
# replica: round-robin between replicas only without touching the primary, # replica: round-robin between replicas only without touching the primary,
# primary: all queries go to the primary unless otherwise specified. # primary: all queries go to the primary unless otherwise specified.
default_role = "any" default_role = "any"
# Query parser. If enabled, we'll attempt to parse
# every incoming query to determine if it's a read or a write.
# If it's a read query, we'll direct it to a replica. Otherwise, if it's a write,
# we'll direct it to the primary.
query_parser_enabled = false
# If the query parser is enabled and this setting is enabled, the primary will be part of the pool of databases used for
# load balancing of read queries. Otherwise, the primary will only be used for write
# queries. The primary can always be explicitely selected with our custom protocol.
primary_reads_enabled = true

View File

@@ -10,7 +10,6 @@ use tokio::net::{
use std::collections::HashMap; use std::collections::HashMap;
use crate::config::Role;
use crate::constants::*; use crate::constants::*;
use crate::errors::Error; use crate::errors::Error;
use crate::messages::*; use crate::messages::*;
@@ -47,10 +46,6 @@ pub struct Client {
// to connect and cancel a query. // to connect and cancel a query.
client_server_map: ClientServerMap, client_server_map: ClientServerMap,
// Unless client specifies, route queries to the servers that have this role,
// e.g. primary or replicas or any.
default_server_role: Option<Role>,
// Client parameters, e.g. user, client_encoding, etc. // Client parameters, e.g. user, client_encoding, etc.
#[allow(dead_code)] #[allow(dead_code)]
parameters: HashMap<String, String>, parameters: HashMap<String, String>,
@@ -67,7 +62,6 @@ impl Client {
mut stream: TcpStream, mut stream: TcpStream,
client_server_map: ClientServerMap, client_server_map: ClientServerMap,
transaction_mode: bool, transaction_mode: bool,
default_server_role: Option<Role>,
server_info: BytesMut, server_info: BytesMut,
stats: Reporter, stats: Reporter,
) -> Result<Client, Error> { ) -> Result<Client, Error> {
@@ -126,7 +120,6 @@ impl Client {
process_id: process_id, process_id: process_id,
secret_key: secret_key, secret_key: secret_key,
client_server_map: client_server_map, client_server_map: client_server_map,
default_server_role: default_server_role,
parameters: parameters, parameters: parameters,
stats: stats, stats: stats,
}); });
@@ -148,7 +141,6 @@ impl Client {
process_id: process_id, process_id: process_id,
secret_key: secret_key, secret_key: secret_key,
client_server_map: client_server_map, client_server_map: client_server_map,
default_server_role: default_server_role,
parameters: HashMap::new(), parameters: HashMap::new(),
stats: stats, stats: stats,
}); });
@@ -162,7 +154,11 @@ impl Client {
} }
/// Client loop. We handle all messages between the client and the database here. /// Client loop. We handle all messages between the client and the database here.
pub async fn handle(&mut self, mut pool: ConnectionPool) -> Result<(), Error> { pub async fn handle(
&mut self,
mut pool: ConnectionPool,
mut query_router: QueryRouter,
) -> Result<(), Error> {
// The client wants to cancel a query it has issued previously. // The client wants to cancel a query it has issued previously.
if self.cancel_mode { if self.cancel_mode {
let (process_id, secret_key, address, port) = { let (process_id, secret_key, address, port) = {
@@ -191,8 +187,6 @@ impl Client {
return Ok(Server::cancel(&address, &port, process_id, secret_key).await?); return Ok(Server::cancel(&address, &port, process_id, secret_key).await?);
} }
let mut query_router = QueryRouter::new(self.default_server_role, pool.shards());
// Our custom protocol loop. // Our custom protocol loop.
// We expect the client to either start a transaction with regular queries // We expect the client to either start a transaction with regular queries
// or issue commands for our sharding and server selection protocols. // or issue commands for our sharding and server selection protocols.
@@ -222,6 +216,11 @@ impl Client {
continue; continue;
} }
// Attempt to parse the query to determine where it should go
if query_router.query_parser_enabled() && query_router.role() == None {
query_router.infer_role(message.clone());
}
// Grab a server from the pool: the client issued a regular query. // Grab a server from the pool: the client issued a regular query.
let connection = match pool.get(query_router.shard(), query_router.role()).await { let connection = match pool.get(query_router.shard(), query_router.role()).await {
Ok(conn) => conn, Ok(conn) => conn,

View File

@@ -13,6 +13,24 @@ pub enum Role {
Replica, Replica,
} }
impl PartialEq<Option<Role>> for Role {
fn eq(&self, other: &Option<Role>) -> bool {
match other {
None => true,
Some(role) => *self == *role,
}
}
}
impl PartialEq<Role> for Option<Role> {
fn eq(&self, other: &Role) -> bool {
match *self {
None => true,
Some(role) => role == *other,
}
}
}
#[derive(Clone, PartialEq, Hash, std::cmp::Eq, Debug)] #[derive(Clone, PartialEq, Hash, std::cmp::Eq, Debug)]
pub struct Address { pub struct Address {
pub host: String, pub host: String,
@@ -47,6 +65,8 @@ pub struct Shard {
#[derive(Deserialize, Debug, Clone)] #[derive(Deserialize, Debug, Clone)]
pub struct QueryRouter { pub struct QueryRouter {
pub default_role: String, pub default_role: String,
pub query_parser_enabled: bool,
pub primary_reads_enabled: bool,
} }
#[derive(Deserialize, Debug, Clone)] #[derive(Deserialize, Debug, Clone)]

View File

@@ -16,11 +16,13 @@
extern crate async_trait; extern crate async_trait;
extern crate bb8; extern crate bb8;
extern crate bytes; extern crate bytes;
extern crate log;
extern crate md5; extern crate md5;
extern crate num_cpus; extern crate num_cpus;
extern crate once_cell; extern crate once_cell;
extern crate serde; extern crate serde;
extern crate serde_derive; extern crate serde_derive;
extern crate sqlparser;
extern crate statsd; extern crate statsd;
extern crate tokio; extern crate tokio;
extern crate toml; extern crate toml;
@@ -47,6 +49,7 @@ mod stats;
// secret keys to the backend's. // secret keys to the backend's.
use config::Role; use config::Role;
use pool::{ClientServerMap, ConnectionPool}; use pool::{ClientServerMap, ConnectionPool};
use query_router::QueryRouter;
use stats::{Collector, Reporter}; use stats::{Collector, Reporter};
/// Main! /// Main!
@@ -118,6 +121,8 @@ async fn main() {
return; return;
} }
}; };
let primary_reads_enabled = config.query_router.primary_reads_enabled;
let query_parser_enabled = config.query_router.query_parser_enabled;
let server_info = match pool.validate().await { let server_info = match pool.validate().await {
Ok(info) => info, Ok(info) => info,
@@ -155,7 +160,6 @@ async fn main() {
socket, socket,
client_server_map, client_server_map,
transaction_mode, transaction_mode,
default_server_role,
server_info, server_info,
reporter, reporter,
) )
@@ -164,7 +168,14 @@ async fn main() {
Ok(mut client) => { Ok(mut client) => {
println!(">> Client {:?} authenticated successfully!", addr); println!(">> Client {:?} authenticated successfully!", addr);
match client.handle(pool).await { let query_router = QueryRouter::new(
default_server_role,
pool.shards(),
primary_reads_enabled,
query_parser_enabled,
);
match client.handle(pool, query_router).await {
Ok(()) => { Ok(()) => {
let duration = chrono::offset::Utc::now().naive_utc() - start; let duration = chrono::offset::Utc::now().naive_utc() - start;

View File

@@ -25,7 +25,6 @@ pub struct ConnectionPool {
banlist: BanList, banlist: BanList,
healthcheck_timeout: u64, healthcheck_timeout: u64,
ban_time: i64, ban_time: i64,
pool_size: u32,
stats: Reporter, stats: Reporter,
} }
@@ -47,12 +46,12 @@ impl ConnectionPool {
.collect::<Vec<String>>(); .collect::<Vec<String>>();
shard_ids.sort_by_key(|k| k.parse::<i64>().unwrap()); shard_ids.sort_by_key(|k| k.parse::<i64>().unwrap());
for shard in shard_ids { for shard_idx in shard_ids {
let shard = &config.shards[&shard]; let shard = &config.shards[&shard_idx];
let mut pools = Vec::new(); let mut pools = Vec::new();
let mut replica_addresses = Vec::new(); let mut replica_addresses = Vec::new();
for (idx, server) in shard.servers.iter().enumerate() { for server in shard.servers.iter() {
let role = match server.2.as_ref() { let role = match server.2.as_ref() {
"primary" => Role::Primary, "primary" => Role::Primary,
"replica" => Role::Replica, "replica" => Role::Replica,
@@ -66,7 +65,7 @@ impl ConnectionPool {
host: server.0.clone(), host: server.0.clone(),
port: server.1.to_string(), port: server.1.to_string(),
role: role, role: role,
shard: idx, shard: shard_idx.parse::<usize>().unwrap(),
}; };
let manager = ServerPool::new( let manager = ServerPool::new(
@@ -106,7 +105,6 @@ impl ConnectionPool {
banlist: Arc::new(Mutex::new(banlist)), banlist: Arc::new(Mutex::new(banlist)),
healthcheck_timeout: config.general.healthcheck_timeout, healthcheck_timeout: config.general.healthcheck_timeout,
ban_time: config.general.ban_time, ban_time: config.general.ban_time,
pool_size: config.general.pool_size,
stats: stats, stats: stats,
} }
} }
@@ -120,12 +118,12 @@ impl ConnectionPool {
let mut server_infos = Vec::new(); let mut server_infos = Vec::new();
for shard in 0..self.shards() { for shard in 0..self.shards() {
for _ in 0..self.replicas(shard) { for _ in 0..self.servers(shard) {
let connection = match self.get(shard, None).await { let connection = match self.get(shard, None).await {
Ok(conn) => conn, Ok(conn) => conn,
Err(err) => { Err(err) => {
println!("> Shard {} down or misconfigured.", shard); println!("> Shard {} down or misconfigured: {:?}", shard, err);
return Err(err); continue;
} }
}; };
@@ -152,8 +150,6 @@ impl ConnectionPool {
shard: usize, shard: usize,
role: Option<Role>, role: Option<Role>,
) -> Result<(PooledConnection<'_, ServerPool>, Address), Error> { ) -> Result<(PooledConnection<'_, ServerPool>, Address), Error> {
// Set this to false to gain ~3-4% speed.
let with_health_check = true;
let now = Instant::now(); let now = Instant::now();
// We are waiting for a server now. // We are waiting for a server now.
@@ -161,25 +157,6 @@ impl ConnectionPool {
let addresses = &self.addresses[shard]; let addresses = &self.addresses[shard];
// Make sure if a specific role is requested, it's available in the pool.
match role {
Some(role) => {
let role_count = addresses.iter().filter(|&db| db.role == role).count();
if role_count == 0 {
println!(
">> Error: Role '{:?}' requested, but none are configured.",
role
);
return Err(Error::AllServersDown);
}
}
// Any role should be present.
_ => (),
};
let mut allowed_attempts = match role { let mut allowed_attempts = match role {
// Primary-specific queries get one attempt, if the primary is down, // Primary-specific queries get one attempt, if the primary is down,
// nothing we should do about it I think. It's dangerous to retry // nothing we should do about it I think. It's dangerous to retry
@@ -188,9 +165,22 @@ impl ConnectionPool {
// Replicas get to try as many times as there are replicas // Replicas get to try as many times as there are replicas
// and connections in the pool. // and connections in the pool.
_ => self.databases[shard].len() * self.pool_size as usize, _ => addresses.len(),
}; };
let exists = match role {
Some(role) => addresses.iter().filter(|addr| addr.role == role).count() > 0,
None => true,
};
if !exists {
log::error!(
"ConnectionPool::get Requested role {:?}, but none is configured.",
role
);
return Err(Error::BadConfig);
}
while allowed_attempts > 0 { while allowed_attempts > 0 {
// Round-robin each client's queries. // Round-robin each client's queries.
// If a client only sends one query and then disconnects, it doesn't matter // If a client only sends one query and then disconnects, it doesn't matter
@@ -200,23 +190,18 @@ impl ConnectionPool {
let address = &addresses[index]; let address = &addresses[index];
// Make sure you're getting a primary or a replica // Make sure you're getting a primary or a replica
// as per request. // as per request. If no specific role is requested, the first
match role { // available will be chosen.
Some(role) => { if address.role != role {
// Find the specific role the client wants in the pool.
if address.role != role {
continue;
}
}
None => (),
};
if self.is_banned(address, shard, role) {
continue; continue;
} }
allowed_attempts -= 1; allowed_attempts -= 1;
if self.is_banned(address, shard, role) {
continue;
}
// Check if we can connect // Check if we can connect
let mut conn = match self.databases[shard][index].get().await { let mut conn = match self.databases[shard][index].get().await {
Ok(conn) => conn, Ok(conn) => conn,
@@ -227,12 +212,6 @@ impl ConnectionPool {
} }
}; };
if !with_health_check {
self.stats.checkout_time(now.elapsed().as_micros());
self.stats.client_active();
return Ok((conn, address.clone()));
}
// // Check if this server is alive with a health check // // Check if this server is alive with a health check
let server = &mut *conn; let server = &mut *conn;
@@ -299,17 +278,21 @@ impl ConnectionPool {
/// Check if a replica can serve traffic. If all replicas are banned, /// Check if a replica can serve traffic. If all replicas are banned,
/// we unban all of them. Better to try then not to. /// we unban all of them. Better to try then not to.
pub fn is_banned(&self, address: &Address, shard: usize, role: Option<Role>) -> bool { pub fn is_banned(&self, address: &Address, shard: usize, role: Option<Role>) -> bool {
// If primary is requested explicitely, it can never be banned. let replicas_available = match role {
if Some(Role::Primary) == role { Some(Role::Replica) => self.addresses[shard]
return false; .iter()
} .filter(|addr| addr.role == Role::Replica)
.count(),
None => self.addresses[shard].len(),
Some(Role::Primary) => return false, // Primary cannot be banned.
};
// If you're not asking for the primary, // If you're not asking for the primary,
// all databases are treated as replicas. // all databases are treated as replicas.
let mut guard = self.banlist.lock().unwrap(); let mut guard = self.banlist.lock().unwrap();
// Everything is banned = nothing is banned. // Everything is banned = nothing is banned.
if guard[shard].len() == self.databases[shard].len() { if guard[shard].len() == replicas_available {
guard[shard].clear(); guard[shard].clear();
drop(guard); drop(guard);
println!(">> Unbanning all replicas."); println!(">> Unbanning all replicas.");
@@ -337,7 +320,7 @@ impl ConnectionPool {
self.databases.len() self.databases.len()
} }
pub fn replicas(&self, shard: usize) -> usize { pub fn servers(&self, shard: usize) -> usize {
self.addresses[shard].len() self.addresses[shard].len()
} }
} }

View File

@@ -1,8 +1,11 @@
use bytes::{Buf, BytesMut};
/// Route queries automatically based on explicitely requested /// Route queries automatically based on explicitely requested
/// or implied query characteristics. /// or implied query characteristics.
use bytes::{Buf, BytesMut};
use once_cell::sync::OnceCell; use once_cell::sync::OnceCell;
use regex::{Regex, RegexBuilder}; use regex::{Regex, RegexBuilder};
use sqlparser::ast::Statement::{Query, StartTransaction};
use sqlparser::dialect::PostgreSqlDialect;
use sqlparser::parser::Parser;
use crate::config::Role; use crate::config::Role;
use crate::sharding::Sharder; use crate::sharding::Sharder;
@@ -26,6 +29,12 @@ pub struct QueryRouter {
// Should we be talking to a primary or a replica? // Should we be talking to a primary or a replica?
active_role: Option<Role>, active_role: Option<Role>,
// Include the primary into the replica pool?
primary_reads_enabled: bool,
// Should we try to parse queries?
query_parser_enabled: bool,
} }
impl QueryRouter { impl QueryRouter {
@@ -54,13 +63,20 @@ impl QueryRouter {
a && b a && b
} }
pub fn new(default_server_role: Option<Role>, shards: usize) -> QueryRouter { pub fn new(
default_server_role: Option<Role>,
shards: usize,
primary_reads_enabled: bool,
query_parser_enabled: bool,
) -> QueryRouter {
QueryRouter { QueryRouter {
default_server_role: default_server_role, default_server_role: default_server_role,
shards: shards, shards: shards,
active_role: default_server_role, active_role: default_server_role,
active_shard: None, active_shard: None,
primary_reads_enabled: primary_reads_enabled,
query_parser_enabled: query_parser_enabled,
} }
} }
@@ -109,7 +125,7 @@ impl QueryRouter {
} }
} }
// Pick a primary or a replica from the pool. /// Pick a primary or a replica from the pool.
pub fn select_role(&mut self, mut buf: BytesMut) -> bool { pub fn select_role(&mut self, mut buf: BytesMut) -> bool {
let code = buf.get_u8() as char; let code = buf.get_u8() as char;
@@ -150,6 +166,75 @@ impl QueryRouter {
} }
} }
/// Try to infer which server to connect to based on the contents of the query.
pub fn infer_role(&mut self, mut buf: BytesMut) -> bool {
let code = buf.get_u8() as char;
let len = buf.get_i32() as usize;
let query = match code {
'Q' => String::from_utf8_lossy(&buf[..len - 5]).to_string(),
'P' => {
let mut start = 0;
let mut end;
// Skip the name of the prepared statement.
while buf[start] != 0 && start < buf.len() {
start += 1;
}
start += 1; // Skip terminating null
// Find the end of the prepared stmt (\0)
end = start;
while buf[end] != 0 && end < buf.len() {
end += 1;
}
let query = String::from_utf8_lossy(&buf[start..end]).to_string();
query.replace("$", "") // Remove placeholders turning them into "values"
}
_ => return false,
};
let ast = match Parser::parse_sql(&PostgreSqlDialect {}, &query) {
Ok(ast) => ast,
Err(err) => {
log::debug!(
"QueryParser::infer_role could not parse query, error: {:?}, query: {}",
err,
query
);
return false;
}
};
if ast.len() == 0 {
return false;
}
match ast[0] {
// All transactions go to the primary, probably a write.
StartTransaction { .. } => {
self.active_role = Some(Role::Primary);
}
// Likely a read-only query
Query { .. } => {
self.active_role = match self.primary_reads_enabled {
false => Some(Role::Replica), // If primary should not be receiving reads, use a replica.
true => None, // Any server role is fine in this case.
}
}
// Likely a write
_ => {
self.active_role = Some(Role::Primary);
}
};
true
}
/// Get the current desired server role we should be talking to. /// Get the current desired server role we should be talking to.
pub fn role(&self) -> Option<Role> { pub fn role(&self) -> Option<Role> {
self.active_role self.active_role
@@ -169,6 +254,11 @@ impl QueryRouter {
self.active_role = self.default_server_role; self.active_role = self.default_server_role;
self.active_shard = None; self.active_shard = None;
} }
/// Should we attempt to parse queries?
pub fn query_parser_enabled(&self) -> bool {
self.query_parser_enabled
}
} }
#[cfg(test)] #[cfg(test)]
@@ -182,7 +272,7 @@ mod test {
let default_server_role: Option<Role> = None; let default_server_role: Option<Role> = None;
let shards = 5; let shards = 5;
let mut query_router = QueryRouter::new(default_server_role, shards); let mut query_router = QueryRouter::new(default_server_role, shards, false, false);
// Build the special syntax query. // Build the special syntax query.
let mut message = BytesMut::new(); let mut message = BytesMut::new();
@@ -205,7 +295,7 @@ mod test {
let default_server_role: Option<Role> = None; let default_server_role: Option<Role> = None;
let shards = 5; let shards = 5;
let mut query_router = QueryRouter::new(default_server_role, shards); let mut query_router = QueryRouter::new(default_server_role, shards, false, false);
// Build the special syntax query. // Build the special syntax query.
let mut message = BytesMut::new(); let mut message = BytesMut::new();
@@ -229,7 +319,7 @@ mod test {
let default_server_role: Option<Role> = None; let default_server_role: Option<Role> = None;
let shards = 5; let shards = 5;
let query_router = QueryRouter::new(default_server_role, shards); let query_router = QueryRouter::new(default_server_role, shards, false, false);
assert_eq!(query_router.shard(), 0); assert_eq!(query_router.shard(), 0);
assert_eq!(query_router.role(), None); assert_eq!(query_router.role(), None);
@@ -241,7 +331,7 @@ mod test {
let default_server_role: Option<Role> = None; let default_server_role: Option<Role> = None;
let shards = 5; let shards = 5;
let mut query_router = QueryRouter::new(default_server_role, shards); let mut query_router = QueryRouter::new(default_server_role, shards, false, false);
// Build the special syntax query. // Build the special syntax query.
let mut message = BytesMut::new(); let mut message = BytesMut::new();
@@ -256,4 +346,97 @@ mod test {
assert_eq!(query_router.select_shard(message.clone()), false); assert_eq!(query_router.select_shard(message.clone()), false);
assert_eq!(query_router.select_role(message.clone()), false); assert_eq!(query_router.select_role(message.clone()), false);
} }
#[test]
fn test_infer_role_replica() {
QueryRouter::setup();
let default_server_role: Option<Role> = None;
let shards = 5;
let mut query_router = QueryRouter::new(default_server_role, shards, false, false);
let queries = vec![
BytesMut::from(&b"SELECT * FROM items WHERE id = 5\0"[..]),
BytesMut::from(&b"SELECT id, name, value FROM items INNER JOIN prices ON item.id = prices.item_id\0"[..]),
BytesMut::from(&b"WITH t AS (SELECT * FROM items) SELECT * FROM t\0"[..]),
];
for query in &queries {
let mut res = BytesMut::from(&b"Q"[..]);
res.put_i32(query.len() as i32 + 4);
res.put(query.clone());
// It's a recognized query
assert!(query_router.infer_role(res));
assert_eq!(query_router.role(), Some(Role::Replica));
}
}
#[test]
fn test_infer_role_primary() {
QueryRouter::setup();
let default_server_role: Option<Role> = None;
let shards = 5;
let mut query_router = QueryRouter::new(default_server_role, shards, false, false);
let queries = vec![
BytesMut::from(&b"UPDATE items SET name = 'pumpkin' WHERE id = 5\0"[..]),
BytesMut::from(&b"INSERT INTO items (id, name) VALUES (5, 'pumpkin')\0"[..]),
BytesMut::from(&b"DELETE FROM items WHERE id = 5\0"[..]),
BytesMut::from(&b"BEGIN\0"[..]), // Transaction start
];
for query in &queries {
let mut res = BytesMut::from(&b"Q"[..]);
res.put_i32(query.len() as i32 + 4);
res.put(query.clone());
// It's a recognized query
assert!(query_router.infer_role(res));
assert_eq!(query_router.role(), Some(Role::Primary));
}
}
#[test]
fn test_infer_role_primary_reads_enabled() {
QueryRouter::setup();
let default_server_role: Option<Role> = None;
let shards = 5;
let mut query_router = QueryRouter::new(default_server_role, shards, true, false);
let query = BytesMut::from(&b"SELECT * FROM items WHERE id = 5\0"[..]);
let mut res = BytesMut::from(&b"Q"[..]);
res.put_i32(query.len() as i32 + 4);
res.put(query.clone());
assert!(query_router.infer_role(res));
assert_eq!(query_router.role(), None);
}
#[test]
fn test_infer_role_parse_prepared() {
QueryRouter::setup();
let default_server_role: Option<Role> = None;
let shards = 5;
let mut query_router = QueryRouter::new(default_server_role, shards, false, false);
let prepared_stmt = BytesMut::from(
&b"WITH t AS (SELECT * FROM items WHERE name = $1) SELECT * FROM t WHERE id = $2\0"[..],
);
let mut res = BytesMut::from(&b"P"[..]);
res.put_i32(prepared_stmt.len() as i32 + 4 + 1 + 2);
res.put_u8(0);
res.put(prepared_stmt);
res.put_i16(0);
assert!(query_router.infer_role(res));
assert_eq!(query_router.role(), Some(Role::Replica));
}
} }