diff --git a/.circleci/run_tests.sh b/.circleci/run_tests.sh index 8a73102..9ba0686 100644 --- a/.circleci/run_tests.sh +++ b/.circleci/run_tests.sh @@ -34,6 +34,9 @@ psql -e -h 127.0.0.1 -p 6432 -f tests/sharding/query_routing_test_select.sql > / # Replica/primary selection & more sharding tests psql -e -h 127.0.0.1 -p 6432 -f tests/sharding/query_routing_test_primary_replica.sql > /dev/null +# Test reload config +kill -SIGHUP $(pgrep pgcat) + # # ActiveRecord tests! # diff --git a/Cargo.lock b/Cargo.lock index d11410a..befc8f1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -11,6 +11,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "arc-swap" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c5d78ce20460b82d3fa150275ed9d55e21064fc7951177baacf86a145c4a4b1f" + [[package]] name = "async-trait" version = "0.1.52" @@ -318,6 +324,7 @@ dependencies = [ name = "pgcat" version = "0.1.0" dependencies = [ + "arc-swap", "async-trait", "bb8", "bytes", diff --git a/Cargo.toml b/Cargo.toml index ba78b59..7afcc6a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,3 +23,4 @@ once_cell = "1" statsd = "0.15" sqlparser = "0.14" log = "0.4" +arc-swap = "1" diff --git a/README.md b/README.md index 1408d9b..cf955eb 100644 --- a/README.md +++ b/README.md @@ -9,19 +9,18 @@ Meow. PgBouncer rewritten in Rust, with sharding, load balancing and failover su **Alpha**: don't use in production just yet. ## Features - -| **Feature** | **Status** | **Comments** | -|--------------------------------|--------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------| -| Transaction pooling | :heavy_check_mark: | Identical to PgBouncer. | -| Session pooling | :heavy_check_mark: | Identical to PgBouncer. | -| `COPY` support | :heavy_check_mark: | Both `COPY TO` and `COPY FROM` are supported. | -| Query cancellation | :heavy_check_mark: | Supported both in transaction and session pooling modes. | -| Load balancing of read queries | :heavy_check_mark: | Using round-robin between replicas. Primary is included when `primary_reads_enabled` is enabled (default). | -| Sharding | :heavy_check_mark: | Transactions are sharded using `SET SHARD TO` and `SET SHARDING KEY TO` syntax extensions; see examples below. | -| Failover | :heavy_check_mark: | Replicas are tested with a health check. If a health check fails, remaining replicas are attempted; see below for algorithm description and examples. | -| Statistics reporting | :heavy_check_mark: | Statistics similar to PgBouncers are reported via StatsD. | -| Live configuration reloading | :x: :wrench: | On the roadmap; currently config changes require restart. | -| Client authentication | :x: :wrench: | On the roadmap; currently all clients are allowed to connect and one user is used to connect to Postgres. | +| **Feature** | **Status** | **Comments** | +|--------------------------------|-----------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------| +| Transaction pooling | :heavy_check_mark: | Identical to PgBouncer. | +| Session pooling | :heavy_check_mark: | Identical to PgBouncer. | +| `COPY` support | :heavy_check_mark: | Both `COPY TO` and `COPY FROM` are supported. | +| Query cancellation | :heavy_check_mark: | Supported both in transaction and session pooling modes. | +| Load balancing of read queries | :heavy_check_mark: | Using round-robin between replicas. Primary is included when `primary_reads_enabled` is enabled (default). | +| Sharding | :heavy_check_mark: | Transactions are sharded using `SET SHARD TO` and `SET SHARDING KEY TO` syntax extensions; see examples below. | +| Failover | :heavy_check_mark: | Replicas are tested with a health check. If a health check fails, remaining replicas are attempted; see below for algorithm description and examples. | +| Statistics reporting | :heavy_check_mark: | Statistics similar to PgBouncers are reported via StatsD. | +| Live configuration reloading | :construction_worker: | Reload config with a `SIGHUP` to the process, e.g. `kill -s SIGHUP $(pgrep pgcat)`. Not all settings can be reloaded without a restart. | +| Client authentication | :x: :wrench: | On the roadmap; currently all clients are allowed to connect and one user is used to connect to Postgres. | ## Deployment @@ -48,17 +47,17 @@ pgbench -t 1000 -p 6432 -h 127.0.0.1 --protocol extended See [sharding README](./tests/sharding/README.md) for sharding logic testing. -| **Feature** | **Tested in CI** | **Tested manually** | **Comments** | -|----------------------|--------------------|---------------------|--------------------------------------------------------------------------------------------------------------------------| -| Transaction pooling | :heavy_check_mark: | :heavy_check_mark: | Used by default for all tests. | -| Session pooling | :x: | :heavy_check_mark: | Easiest way to test is to enable it and run pgbench - results will be better than transaction pooling as expected. | -| `COPY` | :heavy_check_mark: | :heavy_check_mark: | `pgbench -i` uses `COPY`. `COPY FROM` is tested as well. | -| Query cancellation | :heavy_check_mark: | :heavy_check_mark: | `psql -c 'SELECT pg_sleep(1000);'` and press `Ctrl-C`. | -| Load balancing | :x: | :heavy_check_mark: | We could test this by emitting statistics for each replica and compare them. | -| Failover | :x: | :heavy_check_mark: | Misconfigure a replica in `pgcat.toml` and watch it forward queries to spares. CI testing could include using Toxiproxy. | -| Sharding | :heavy_check_mark: | :heavy_check_mark: | See `tests/sharding` and `tests/ruby` for an Rails/ActiveRecord example. | -| Statistics reporting | :x: | :heavy_check_mark: | Run `nc -l -u 8125` and watch the stats come in every 15 seconds. | - +| **Feature** | **Tested in CI** | **Tested manually** | **Comments** | +|-----------------------|--------------------|---------------------|--------------------------------------------------------------------------------------------------------------------------| +| Transaction pooling | :heavy_check_mark: | :heavy_check_mark: | Used by default for all tests. | +| Session pooling | :x: | :heavy_check_mark: | Easiest way to test is to enable it and run pgbench - results will be better than transaction pooling as expected. | +| `COPY` | :heavy_check_mark: | :heavy_check_mark: | `pgbench -i` uses `COPY`. `COPY FROM` is tested as well. | +| Query cancellation | :heavy_check_mark: | :heavy_check_mark: | `psql -c 'SELECT pg_sleep(1000);'` and press `Ctrl-C`. | +| Load balancing | :x: | :heavy_check_mark: | We could test this by emitting statistics for each replica and compare them. | +| Failover | :x: | :heavy_check_mark: | Misconfigure a replica in `pgcat.toml` and watch it forward queries to spares. CI testing could include using Toxiproxy. | +| Sharding | :heavy_check_mark: | :heavy_check_mark: | See `tests/sharding` and `tests/ruby` for an Rails/ActiveRecord example. | +| Statistics reporting | :x: | :heavy_check_mark: | Run `nc -l -u 8125` and watch the stats come in every 15 seconds. | +| Live config reloading | :heavy_check_mark: | :heavy_check_mark: | Run `kill -s SIGHUP $(pgrep pgcat)` and watch the config reload. | ## Usage @@ -173,6 +172,30 @@ SET SERVER ROLE TO 'auto'; -- let the query router figure out where the query sh SELECT * FROM users WHERE email = 'test@example.com'; -- shard setting lasts until set again; we are reading from the primary ``` +### Statistics reporting + +Stats are reported using StatsD every 15 seconds. The address is configurable with `statsd_address`, the default is `127.0.0.1:8125`. The stats are very similar to what Pgbouncer reports and the names are kept to be comparable. + +### Live configuration reloading + +The config can be reloaded by sending a `kill -s SIGHUP` to the process. Not all settings are currently supported by live reload: + +| **Config** | **Requires restart** | +|-------------------------|----------------------| +| `host` | yes | +| `port` | yes | +| `pool_mode` | no | +| `connect_timeout` | yes | +| `healthcheck_timeout` | no | +| `ban_time` | no | +| `statsd_address` | yes | +| `user` | yes | +| `shards` | yes | +| `default_role` | no | +| `primary_reads_enabled` | no | +| `query_parser_enabled` | no | + + ## Benchmarks You can setup PgBench locally through PgCat: diff --git a/pgcat.toml b/pgcat.toml index 0fa8b6a..9a240b3 100644 --- a/pgcat.toml +++ b/pgcat.toml @@ -29,6 +29,9 @@ healthcheck_timeout = 1000 # For how long to ban a server if it fails a health check (seconds). ban_time = 60 # Seconds +# Stats will be sent here +statsd_address = "127.0.0.1:8125" + # # User to use for authentication against the server. [user] diff --git a/src/client.rs b/src/client.rs index 834bde0..1b921b8 100644 --- a/src/client.rs +++ b/src/client.rs @@ -10,6 +10,7 @@ use tokio::net::{ use std::collections::HashMap; +use crate::config::get_config; use crate::constants::*; use crate::errors::Error; use crate::messages::*; @@ -61,10 +62,12 @@ impl Client { pub async fn startup( mut stream: TcpStream, client_server_map: ClientServerMap, - transaction_mode: bool, server_info: BytesMut, stats: Reporter, ) -> Result { + let config = get_config(); + let transaction_mode = config.general.pool_mode.starts_with("t"); + drop(config); loop { // Could be StartupMessage or SSLRequest // which makes this variable length. @@ -154,11 +157,7 @@ impl Client { } /// Client loop. We handle all messages between the client and the database here. - pub async fn handle( - &mut self, - mut pool: ConnectionPool, - mut query_router: QueryRouter, - ) -> Result<(), Error> { + pub async fn handle(&mut self, mut pool: ConnectionPool) -> Result<(), Error> { // The client wants to cancel a query it has issued previously. if self.cancel_mode { let (process_id, secret_key, address, port) = { @@ -187,6 +186,8 @@ impl Client { return Ok(Server::cancel(&address, &port, process_id, secret_key).await?); } + let mut query_router = QueryRouter::new(); + // Our custom protocol loop. // We expect the client to either start a transaction with regular queries // or issue commands for our sharding and server selection protocols. diff --git a/src/config.rs b/src/config.rs index 8f7c45f..261d596 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,12 +1,17 @@ +use arc_swap::{ArcSwap, Guard}; +use once_cell::sync::Lazy; use serde_derive::Deserialize; use tokio::fs::File; use tokio::io::AsyncReadExt; use toml; use std::collections::{HashMap, HashSet}; +use std::sync::Arc; use crate::errors::Error; +static CONFIG: Lazy> = Lazy::new(|| ArcSwap::from_pointee(Config::default())); + #[derive(Clone, PartialEq, Deserialize, Hash, std::cmp::Eq, Debug, Copy)] pub enum Role { Primary, @@ -39,12 +44,32 @@ pub struct Address { pub role: Role, } +impl Default for Address { + fn default() -> Address { + Address { + host: String::from("127.0.0.1"), + port: String::from("5432"), + shard: 0, + role: Role::Replica, + } + } +} + #[derive(Clone, PartialEq, Hash, std::cmp::Eq, Deserialize, Debug)] pub struct User { pub name: String, pub password: String, } +impl Default for User { + fn default() -> User { + User { + name: String::from("postgres"), + password: String::new(), + } + } +} + #[derive(Deserialize, Debug, Clone)] pub struct General { pub host: String, @@ -54,6 +79,22 @@ pub struct General { pub connect_timeout: u64, pub healthcheck_timeout: u64, pub ban_time: i64, + pub statsd_address: String, +} + +impl Default for General { + fn default() -> General { + General { + host: String::from("localhost"), + port: 5432, + pool_size: 15, + pool_mode: String::from("transaction"), + connect_timeout: 5000, + healthcheck_timeout: 1000, + ban_time: 60, + statsd_address: String::from("127.0.0.1:8125"), + } + } } #[derive(Deserialize, Debug, Clone)] @@ -62,6 +103,15 @@ pub struct Shard { pub database: String, } +impl Default for Shard { + fn default() -> Shard { + Shard { + servers: vec![(String::from("localhost"), 5432, String::from("primary"))], + database: String::from("postgres"), + } + } +} + #[derive(Deserialize, Debug, Clone)] pub struct QueryRouter { pub default_role: String, @@ -69,6 +119,16 @@ pub struct QueryRouter { pub primary_reads_enabled: bool, } +impl Default for QueryRouter { + fn default() -> QueryRouter { + QueryRouter { + default_role: String::from("any"), + query_parser_enabled: false, + primary_reads_enabled: true, + } + } +} + #[derive(Deserialize, Debug, Clone)] pub struct Config { pub general: General, @@ -77,8 +137,36 @@ pub struct Config { pub query_router: QueryRouter, } +impl Default for Config { + fn default() -> Config { + Config { + general: General::default(), + user: User::default(), + shards: HashMap::from([(String::from("1"), Shard::default())]), + query_router: QueryRouter::default(), + } + } +} + +impl Config { + pub fn show(&self) { + println!("> Pool size: {}", self.general.pool_size); + println!("> Pool mode: {}", self.general.pool_mode); + println!("> Ban time: {}s", self.general.ban_time); + println!( + "> Healthcheck timeout: {}ms", + self.general.healthcheck_timeout + ); + println!("> Connection timeout: {}ms", self.general.connect_timeout); + } +} + +pub fn get_config() -> Guard> { + CONFIG.load() +} + /// Parse the config. -pub async fn parse(path: &str) -> Result { +pub async fn parse(path: &str) -> Result<(), Error> { let mut contents = String::new(); let mut file = match File::open(path).await { Ok(file) => file, @@ -163,7 +251,9 @@ pub async fn parse(path: &str) -> Result { } }; - Ok(config) + CONFIG.store(Arc::new(config.clone())); + + Ok(()) } #[cfg(test)] @@ -172,11 +262,11 @@ mod test { #[tokio::test] async fn test_config() { - let config = parse("pgcat.toml").await.unwrap(); - assert_eq!(config.general.pool_size, 15); - assert_eq!(config.shards.len(), 3); - assert_eq!(config.shards["1"].servers[0].0, "127.0.0.1"); - assert_eq!(config.shards["0"].servers[0].2, "primary"); - assert_eq!(config.query_router.default_role, "any"); + parse("pgcat.toml").await.unwrap(); + assert_eq!(get_config().general.pool_size, 15); + assert_eq!(get_config().shards.len(), 3); + assert_eq!(get_config().shards["1"].servers[0].0, "127.0.0.1"); + assert_eq!(get_config().shards["0"].servers[0].2, "primary"); + assert_eq!(get_config().query_router.default_role, "any"); } } diff --git a/src/main.rs b/src/main.rs index 1669331..a12137b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -13,6 +13,7 @@ // You should have received a copy of the GNU General Public License // along with this program. If not, see . +extern crate arc_swap; extern crate async_trait; extern crate bb8; extern crate bytes; @@ -28,7 +29,10 @@ extern crate tokio; extern crate toml; use tokio::net::TcpListener; -use tokio::signal; +use tokio::{ + signal, + signal::unix::{signal as unix_signal, SignalKind}, +}; use std::collections::HashMap; use std::sync::{Arc, Mutex}; @@ -47,9 +51,8 @@ mod stats; // Support for query cancellation: this maps our process_ids and // secret keys to the backend's. -use config::Role; +use config::get_config; use pool::{ClientServerMap, ConnectionPool}; -use query_router::QueryRouter; use stats::{Collector, Reporter}; /// Main! @@ -63,14 +66,17 @@ async fn main() { return; } - let config = match config::parse("pgcat.toml").await { - Ok(config) => config, + // Prepare the config + match config::parse("pgcat.toml").await { + Ok(_) => (), Err(err) => { println!("> Config parse error: {:?}", err); return; } }; + let config = get_config(); + let addr = format!("{}:{}", config.general.host, config.general.port); let listener = match TcpListener::bind(&addr).await { Ok(sock) => sock, @@ -81,19 +87,11 @@ async fn main() { }; println!("> Running on {}", addr); + config.show(); // Tracks which client is connected to which server for query cancellation. let client_server_map: ClientServerMap = Arc::new(Mutex::new(HashMap::new())); - println!("> Pool size: {}", config.general.pool_size); - println!("> Pool mode: {}", config.general.pool_mode); - println!("> Ban time: {}s", config.general.ban_time); - println!( - "> Healthcheck timeout: {}ms", - config.general.healthcheck_timeout - ); - println!("> Connection timeout: {}ms", config.general.connect_timeout); - // Collect statistics and send them to StatsD let (tx, rx) = mpsc::channel(100); @@ -104,25 +102,8 @@ async fn main() { stats_collector.collect().await; }); - let mut pool = ConnectionPool::from_config( - config.clone(), - client_server_map.clone(), - Reporter::new(tx.clone()), - ) - .await; - - let transaction_mode = config.general.pool_mode == "transaction"; - let default_server_role = match config.query_router.default_role.as_ref() { - "any" => None, - "primary" => Some(Role::Primary), - "replica" => Some(Role::Replica), - _ => { - println!("> Config error, got unexpected query_router.default_role."); - return; - } - }; - let primary_reads_enabled = config.query_router.primary_reads_enabled; - let query_parser_enabled = config.query_router.query_parser_enabled; + let mut pool = + ConnectionPool::from_config(client_server_map.clone(), Reporter::new(tx.clone())).await; let server_info = match pool.validate().await { Ok(info) => info, @@ -156,26 +137,13 @@ async fn main() { println!(">> Client {:?} connected", addr); - match client::Client::startup( - socket, - client_server_map, - transaction_mode, - server_info, - reporter, - ) - .await + match client::Client::startup(socket, client_server_map, server_info, reporter) + .await { Ok(mut client) => { println!(">> Client {:?} authenticated successfully!", addr); - let query_router = QueryRouter::new( - default_server_role, - pool.shards(), - primary_reads_enabled, - query_parser_enabled, - ); - - match client.handle(pool, query_router).await { + match client.handle(pool).await { Ok(()) => { let duration = chrono::offset::Utc::now().naive_utc() - start; @@ -201,6 +169,26 @@ async fn main() { } }); + // Reload config + // kill -SIGHUP $(pgrep pgcat) + tokio::task::spawn(async move { + let mut stream = unix_signal(SignalKind::hangup()).unwrap(); + + loop { + stream.recv().await; + println!("> Reloading config"); + match config::parse("pgcat.toml").await { + Ok(_) => { + get_config().show(); + } + Err(err) => { + println!("> Config parse error: {:?}", err); + return; + } + }; + } + }); + // Setup shut down sequence match signal::ctrl_c().await { Ok(()) => { diff --git a/src/messages.rs b/src/messages.rs index 70dbc10..16b2f84 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -38,6 +38,7 @@ pub async fn backend_key_data( Ok(write_all(stream, key_data).await?) } +#[allow(dead_code)] pub fn simple_query(query: &str) -> BytesMut { let mut res = BytesMut::from(&b"Q"[..]); let query = format!("{}\0", query); diff --git a/src/pool.rs b/src/pool.rs index 6a190f1..29ee40b 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -4,7 +4,7 @@ use bb8::{ManageConnection, Pool, PooledConnection}; use bytes::BytesMut; use chrono::naive::NaiveDateTime; -use crate::config::{Address, Config, Role, User}; +use crate::config::{get_config, Address, Role, User}; use crate::errors::Error; use crate::server::Server; use crate::stats::Reporter; @@ -23,18 +23,16 @@ pub struct ConnectionPool { addresses: Vec>, round_robin: usize, banlist: BanList, - healthcheck_timeout: u64, - ban_time: i64, stats: Reporter, } impl ConnectionPool { /// Construct the connection pool from a config file. pub async fn from_config( - config: Config, client_server_map: ClientServerMap, stats: Reporter, ) -> ConnectionPool { + let config = get_config(); let mut shards = Vec::new(); let mut addresses = Vec::new(); let mut banlist = Vec::new(); @@ -103,8 +101,6 @@ impl ConnectionPool { addresses: addresses, round_robin: rand::random::() % address_len, // Start at a random replica banlist: Arc::new(Mutex::new(banlist)), - healthcheck_timeout: config.general.healthcheck_timeout, - ban_time: config.general.ban_time, stats: stats, } } @@ -214,9 +210,10 @@ impl ConnectionPool { // // Check if this server is alive with a health check let server = &mut *conn; + let healthcheck_timeout = get_config().general.healthcheck_timeout; match tokio::time::timeout( - tokio::time::Duration::from_millis(self.healthcheck_timeout), + tokio::time::Duration::from_millis(healthcheck_timeout), server.query("SELECT 1"), ) .await @@ -303,8 +300,9 @@ impl ConnectionPool { match guard[shard].get(address) { Some(timestamp) => { let now = chrono::offset::Utc::now().naive_utc(); + let config = get_config(); // Ban expired. - if now.timestamp() - timestamp.timestamp() > self.ban_time { + if now.timestamp() - timestamp.timestamp() > config.general.ban_time { guard[shard].remove(address); false } else { diff --git a/src/query_router.rs b/src/query_router.rs index c43ca49..097cdef 100644 --- a/src/query_router.rs +++ b/src/query_router.rs @@ -1,4 +1,4 @@ -use crate::config::Role; +use crate::config::{get_config, Role}; use crate::sharding::Sharder; /// Route queries automatically based on explicitely requested /// or implied query characteristics. @@ -65,20 +65,24 @@ impl QueryRouter { } } - pub fn new( - default_server_role: Option, - shards: usize, - primary_reads_enabled: bool, - query_parser_enabled: bool, - ) -> QueryRouter { + pub fn new() -> QueryRouter { + let config = get_config(); + + let default_server_role = match config.query_router.default_role.as_ref() { + "any" => None, + "primary" => Some(Role::Primary), + "replica" => Some(Role::Replica), + _ => unreachable!(), + }; + QueryRouter { default_server_role: default_server_role, - shards: shards, + shards: config.shards.len(), active_role: default_server_role, active_shard: None, - primary_reads_enabled: primary_reads_enabled, - query_parser_enabled: query_parser_enabled, + primary_reads_enabled: config.query_router.primary_reads_enabled, + query_parser_enabled: config.query_router.query_parser_enabled, } } @@ -275,6 +279,11 @@ impl QueryRouter { pub fn query_parser_enabled(&self) -> bool { self.query_parser_enabled } + + #[allow(dead_code)] + pub fn toggle_primary_reads(&mut self, value: bool) { + self.primary_reads_enabled = value; + } } #[cfg(test)] @@ -286,10 +295,7 @@ mod test { #[test] fn test_defaults() { QueryRouter::setup(); - - let default_server_role: Option = None; - let shards = 5; - let qr = QueryRouter::new(default_server_role, shards, false, false); + let qr = QueryRouter::new(); assert_eq!(qr.role(), None); } @@ -297,10 +303,10 @@ mod test { #[test] fn test_infer_role_replica() { QueryRouter::setup(); - - let default_server_role: Option = None; - let shards = 5; - let mut qr = QueryRouter::new(default_server_role, shards, false, false); + let mut qr = QueryRouter::new(); + assert!(qr.try_execute_command(simple_query("SET SERVER ROLE TO 'auto'")) != None); + assert_eq!(qr.query_parser_enabled(), true); + qr.toggle_primary_reads(false); let queries = vec![ simple_query("SELECT * FROM items WHERE id = 5"), @@ -320,10 +326,7 @@ mod test { #[test] fn test_infer_role_primary() { QueryRouter::setup(); - - let default_server_role: Option = None; - let shards = 5; - let mut qr = QueryRouter::new(default_server_role, shards, false, false); + let mut qr = QueryRouter::new(); let queries = vec![ simple_query("UPDATE items SET name = 'pumpkin' WHERE id = 5"), @@ -342,11 +345,9 @@ mod test { #[test] fn test_infer_role_primary_reads_enabled() { QueryRouter::setup(); - - let default_server_role: Option = None; - let shards = 5; - let mut qr = QueryRouter::new(default_server_role, shards, true, false); + let mut qr = QueryRouter::new(); let query = simple_query("SELECT * FROM items WHERE id = 5"); + qr.toggle_primary_reads(true); assert!(qr.infer_role(query)); assert_eq!(qr.role(), None); @@ -355,11 +356,9 @@ mod test { #[test] fn test_infer_role_parse_prepared() { QueryRouter::setup(); - - let default_server_role: Option = None; - let shards = 5; - - let mut query_router = QueryRouter::new(default_server_role, shards, false, false); + let mut qr = QueryRouter::new(); + qr.try_execute_command(simple_query("SET SERVER ROLE TO 'auto'")); + qr.toggle_primary_reads(false); let prepared_stmt = BytesMut::from( &b"WITH t AS (SELECT * FROM items WHERE name = $1) SELECT * FROM t WHERE id = $2\0"[..], @@ -370,8 +369,8 @@ mod test { res.put(prepared_stmt); res.put_i16(0); - assert!(query_router.infer_role(res)); - assert_eq!(query_router.role(), Some(Role::Replica)); + assert!(qr.infer_role(res)); + assert_eq!(qr.role(), Some(Role::Replica)); } #[test] @@ -411,15 +410,15 @@ mod test { #[test] fn test_try_execute_command() { QueryRouter::setup(); - let mut qr = QueryRouter::new(Some(Role::Primary), 5, false, false); + let mut qr = QueryRouter::new(); // SetShardingKey let query = simple_query("SET SHARDING KEY TO '13'"); assert_eq!( qr.try_execute_command(query), - Some((Command::SetShardingKey, String::from("3"))) + Some((Command::SetShardingKey, String::from("1"))) ); - assert_eq!(qr.shard(), 3); + assert_eq!(qr.shard(), 1); // SetShard let query = simple_query("SET SHARD TO '1'"); @@ -468,8 +467,9 @@ mod test { #[test] fn test_enable_query_parser() { QueryRouter::setup(); - let mut qr = QueryRouter::new(None, 5, false, false); + let mut qr = QueryRouter::new(); let query = simple_query("SET SERVER ROLE TO 'auto'"); + qr.toggle_primary_reads(false); assert!(qr.try_execute_command(query) != None); assert!(qr.query_parser_enabled()); diff --git a/src/stats.rs b/src/stats.rs index 3f93b3f..f3e0790 100644 --- a/src/stats.rs +++ b/src/stats.rs @@ -5,6 +5,8 @@ use tokio::sync::mpsc::{Receiver, Sender}; use std::collections::HashMap; use std::time::Instant; +use crate::config::get_config; + #[derive(Debug)] pub enum StatisticName { CheckoutTime, @@ -138,7 +140,7 @@ impl Collector { pub fn new(rx: Receiver) -> Collector { Collector { rx: rx, - client: Client::new("127.0.0.1:8125", "pgcat").unwrap(), + client: Client::new(&get_config().general.statsd_address, "pgcat").unwrap(), } }