Add support for multi-database / multi-user pools (#96)

* Add support for multi-database / multi-user pools

* Nothing

* cargo fmt

* CI

* remove test users

* rename pool

* Update tests to use admin user/pass

* more fixes

* Revert bad change

* Use PGDATABASE env var

* send server info in case of admin
This commit is contained in:
Mostafa Abdelraouf
2022-07-27 21:47:55 -05:00
committed by GitHub
parent c5be5565a5
commit 2ae4b438e3
14 changed files with 700 additions and 503 deletions

View File

@@ -5,21 +5,12 @@
#
# General pooler settings
[general]
# What IP to run on, 0.0.0.0 means accessible from everywhere.
host = "0.0.0.0"
# Port to run on, same as PgBouncer used in this example.
port = 6432
# How many connections to allocate per server.
pool_size = 15
# Pool mode (see PgBouncer docs for more).
# session: one server connection per connected client
# transaction: one server connection per client transaction
pool_mode = "transaction"
# How long to wait before aborting a server connection (ms).
connect_timeout = 100
@@ -29,56 +20,27 @@ healthcheck_timeout = 100
# For how long to ban a server if it fails a health check (seconds).
ban_time = 60 # Seconds
#
# Reload config automatically if it changes.
autoreload = true
# TLS
tls_certificate = ".circleci/server.cert"
tls_private_key = ".circleci/server.key"
#
# User to use for authentication against the server.
[user]
name = "sharding_user"
password = "sharding_user"
# Credentials to access the virtual administrative database (pgbouncer or pgcat)
# Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DATABASES`, etc..
admin_username = "admin_user"
admin_password = "admin_pass"
#
# Shards in the cluster
[shards]
# Shard 0
[shards.0]
# [ host, port, role ]
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5433, "replica" ],
# [ "127.0.1.1", 5432, "replica" ],
]
# Database name (e.g. "postgres")
database = "shard0"
[shards.1]
# [ host, port, role ]
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5433, "replica" ],
# [ "127.0.1.1", 5432, "replica" ],
]
database = "shard1"
[shards.2]
# [ host, port, role ]
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5433, "replica" ],
# [ "127.0.1.1", 5432, "replica" ],
]
database = "shard2"
# Settings for our query routing layer.
[query_router]
# pool
# configs are structured as pool.<pool_name>
# the pool_name is what clients use as database name when connecting
# For the example below a client can connect using "postgres://sharding_user:sharding_user@pgcat_host:pgcat_port/sharded"
[pools.sharded_db]
# Pool mode (see PgBouncer docs for more).
# session: one server connection per connected client
# transaction: one server connection per client transaction
pool_mode = "transaction"
# If the client doesn't specify, route traffic to
# this role by default.
@@ -88,7 +50,6 @@ database = "shard2"
# primary: all queries go to the primary unless otherwise specified.
default_role = "any"
# Query parser. If enabled, we'll attempt to parse
# every incoming query to determine if it's a read or a write.
# If it's a read query, we'll direct it to a replica. Otherwise, if it's a write,
@@ -109,3 +70,36 @@ primary_reads_enabled = true
# sha1: A hashing function based on SHA1
#
sharding_function = "pg_bigint_hash"
# Credentials for users that may connect to this cluster
[pools.sharded_db.users.0]
username = "sharding_user"
password = "sharding_user"
# Maximum number of server connections that can be established for this user
# The maximum number of connection from a single Pgcat process to any database in the cluster
# is the sum of pool_size across all users.
pool_size = 9
# Shard 0
[pools.sharded_db.shards.0]
# [ host, port, role ]
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ]
]
# Database name (e.g. "postgres")
database = "shard0"
[pools.sharded_db.shards.1]
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ],
]
database = "shard1"
[pools.sharded_db.shards.2]
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ],
]
database = "shard2"

View File

@@ -32,6 +32,7 @@ toxiproxy-cli create -l 127.0.0.1:5433 -u 127.0.0.1:5432 postgres_replica
start_pgcat "info"
export PGPASSWORD=sharding_user
export PGDATABASE=sharded_db
# pgbench test
pgbench -U sharding_user -i -h 127.0.0.1 -p 6432
@@ -47,7 +48,7 @@ sleep 1
killall psql -s SIGINT
# Reload pool (closing unused server connections)
psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'RELOAD'
PGPASSWORD=admin_pass psql -U admin_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'RELOAD'
(psql -U sharding_user -h 127.0.0.1 -p 6432 -c 'SELECT pg_sleep(50)' || true) &
sleep 1
@@ -72,15 +73,17 @@ cd tests/ruby && \
cd ../..
# Admin tests
psql -U sharding_user -e -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW STATS' > /dev/null
psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'RELOAD' > /dev/null
psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW CONFIG' > /dev/null
psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW DATABASES' > /dev/null
psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW LISTS' > /dev/null
psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW POOLS' > /dev/null
psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW VERSION' > /dev/null
psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c "SET client_encoding TO 'utf8'" > /dev/null # will ignore
(! psql -U sharding_user -e -h 127.0.0.1 -p 6432 -d random_db -c 'SHOW STATS' > /dev/null)
export PGPASSWORD=admin_pass
psql -U admin_user -e -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW STATS' > /dev/null
psql -U admin_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'RELOAD' > /dev/null
psql -U admin_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW CONFIG' > /dev/null
psql -U admin_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW DATABASES' > /dev/null
psql -U admin_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW LISTS' > /dev/null
psql -U admin_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW POOLS' > /dev/null
psql -U admin_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW VERSION' > /dev/null
psql -U admin_user -h 127.0.0.1 -p 6432 -d pgbouncer -c "SET client_encoding TO 'utf8'" > /dev/null # will ignore
(! psql -U admin_user -e -h 127.0.0.1 -p 6432 -d random_db -c 'SHOW STATS' > /dev/null)
export PGPASSWORD=sharding_user
# Start PgCat in debug to demonstrate failover better
start_pgcat "trace"

2
Cargo.lock generated
View File

@@ -395,7 +395,7 @@ dependencies = [
[[package]]
name = "pgcat"
version = "0.4.0-beta1"
version = "0.6.0-alpha1"
dependencies = [
"arc-swap",
"async-trait",

View File

@@ -1,6 +1,6 @@
[package]
name = "pgcat"
version = "0.4.0-beta1"
version = "0.6.0-alpha1"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

View File

@@ -5,21 +5,12 @@
#
# General pooler settings
[general]
# What IP to run on, 0.0.0.0 means accessible from everywhere.
host = "0.0.0.0"
# Port to run on, same as PgBouncer used in this example.
port = 6432
# How many connections to allocate per server.
pool_size = 15
# Pool mode (see PgBouncer docs for more).
# session: one server connection per connected client
# transaction: one server connection per client transaction
pool_mode = "transaction"
# How long to wait before aborting a server connection (ms).
connect_timeout = 5000
@@ -27,7 +18,7 @@ connect_timeout = 5000
healthcheck_timeout = 1000
# For how long to ban a server if it fails a health check (seconds).
ban_time = 60 # Seconds
ban_time = 60 # seconds
# Reload config automatically if it changes.
autoreload = false
@@ -36,50 +27,20 @@ autoreload = false
# tls_certificate = "server.cert"
# tls_private_key = "server.key"
#
# User to use for authentication against the server.
[user]
name = "sharding_user"
password = "sharding_user"
# Credentials to access the virtual administrative database (pgbouncer or pgcat)
# Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DATABASES`, etc..
admin_username = "user"
admin_password = "pass"
#
# Shards in the cluster
[shards]
# Shard 0
[shards.0]
# [ host, port, role ]
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ],
# [ "127.0.1.1", 5432, "replica" ],
]
# Database name (e.g. "postgres")
database = "shard0"
[shards.1]
# [ host, port, role ]
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ],
# [ "127.0.1.1", 5432, "replica" ],
]
database = "shard1"
[shards.2]
# [ host, port, role ]
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ],
# [ "127.0.1.1", 5432, "replica" ],
]
database = "shard2"
# Settings for our query routing layer.
[query_router]
# pool
# configs are structured as pool.<pool_name>
# the pool_name is what clients use as database name when connecting
# For the example below a client can connect using "postgres://sharding_user:sharding_user@pgcat_host:pgcat_port/sharded"
[pools.sharded]
# Pool mode (see PgBouncer docs for more).
# session: one server connection per connected client
# transaction: one server connection per client transaction
pool_mode = "transaction"
# If the client doesn't specify, route traffic to
# this role by default.
@@ -89,7 +50,6 @@ database = "shard2"
# primary: all queries go to the primary unless otherwise specified.
default_role = "any"
# Query parser. If enabled, we'll attempt to parse
# every incoming query to determine if it's a read or a write.
# If it's a read query, we'll direct it to a replica. Otherwise, if it's a write,
@@ -110,3 +70,61 @@ primary_reads_enabled = true
# sha1: A hashing function based on SHA1
#
sharding_function = "pg_bigint_hash"
# Credentials for users that may connect to this cluster
[pools.sharded.users.0]
username = "sharding_user"
password = "sharding_user"
# Maximum number of server connections that can be established for this user
# The maximum number of connection from a single Pgcat process to any database in the cluster
# is the sum of pool_size across all users.
pool_size = 9
[pools.sharded.users.1]
username = "other_user"
password = "other_user"
pool_size = 21
# Shard 0
[pools.sharded.shards.0]
# [ host, port, role ]
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ]
]
# Database name (e.g. "postgres")
database = "shard0"
[pools.sharded.shards.1]
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ],
]
database = "shard1"
[pools.sharded.shards.2]
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ],
]
database = "shard2"
[pools.simple_db]
pool_mode = "session"
default_role = "primary"
query_parser_enabled = true
primary_reads_enabled = true
sharding_function = "pg_bigint_hash"
[pools.simple_db.users.0]
username = "simple_user"
password = "simple_user"
pool_size = 5
[pools.simple_db.shards.0]
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ]
]
database = "some_db"

View File

@@ -3,10 +3,10 @@ use bytes::{Buf, BufMut, BytesMut};
use log::{info, trace};
use std::collections::HashMap;
use crate::config::{get_config, reload_config};
use crate::config::{get_config, reload_config, VERSION};
use crate::errors::Error;
use crate::messages::*;
use crate::pool::ConnectionPool;
use crate::pool::get_all_pools;
use crate::stats::get_stats;
use crate::ClientServerMap;
@@ -14,7 +14,6 @@ use crate::ClientServerMap;
pub async fn handle_admin<T>(
stream: &mut T,
mut query: BytesMut,
pool: ConnectionPool,
client_server_map: ClientServerMap,
) -> Result<(), Error>
where
@@ -35,7 +34,7 @@ where
if query.starts_with("SHOW STATS") {
trace!("SHOW STATS");
show_stats(stream, &pool).await
show_stats(stream).await
} else if query.starts_with("RELOAD") {
trace!("RELOAD");
reload(stream, client_server_map).await
@@ -44,13 +43,13 @@ where
show_config(stream).await
} else if query.starts_with("SHOW DATABASES") {
trace!("SHOW DATABASES");
show_databases(stream, &pool).await
show_databases(stream).await
} else if query.starts_with("SHOW POOLS") {
trace!("SHOW POOLS");
show_pools(stream, &pool).await
show_pools(stream).await
} else if query.starts_with("SHOW LISTS") {
trace!("SHOW LISTS");
show_lists(stream, &pool).await
show_lists(stream).await
} else if query.starts_with("SHOW VERSION") {
trace!("SHOW VERSION");
show_version(stream).await
@@ -63,7 +62,7 @@ where
}
/// Column-oriented statistics.
async fn show_lists<T>(stream: &mut T, pool: &ConnectionPool) -> Result<(), Error>
async fn show_lists<T>(stream: &mut T) -> Result<(), Error>
where
T: tokio::io::AsyncWrite + std::marker::Unpin,
{
@@ -71,17 +70,20 @@ where
let columns = vec![("list", DataType::Text), ("items", DataType::Int4)];
let mut users = 1;
let mut databases = 1;
for (_, pool) in get_all_pools() {
databases += pool.databases();
users += 1; // One user per pool
}
let mut res = BytesMut::new();
res.put(row_description(&columns));
res.put(data_row(&vec![
"databases".to_string(),
(pool.databases() + 1).to_string(), // see comment below
databases.to_string(),
]));
res.put(data_row(&vec!["users".to_string(), "1".to_string()]));
res.put(data_row(&vec![
"pools".to_string(),
(pool.databases() + 1).to_string(), // +1 for the pgbouncer admin db pool which isn't real
])); // but admin tools that work with pgbouncer want this
res.put(data_row(&vec!["users".to_string(), users.to_string()]));
res.put(data_row(&vec!["pools".to_string(), databases.to_string()]));
res.put(data_row(&vec![
"free_clients".to_string(),
stats
@@ -140,7 +142,7 @@ where
let mut res = BytesMut::new();
res.put(row_description(&vec![("version", DataType::Text)]));
res.put(data_row(&vec!["PgCat 0.1.0".to_string()]));
res.put(data_row(&vec![format!("PgCat {}", VERSION).to_string()]));
res.put(command_complete("SHOW"));
res.put_u8(b'Z');
@@ -151,12 +153,11 @@ where
}
/// Show utilization of connection pools for each shard and replicas.
async fn show_pools<T>(stream: &mut T, pool: &ConnectionPool) -> Result<(), Error>
async fn show_pools<T>(stream: &mut T) -> Result<(), Error>
where
T: tokio::io::AsyncWrite + std::marker::Unpin,
{
let stats = get_stats();
let config = get_config();
let columns = vec![
("database", DataType::Text),
@@ -176,24 +177,26 @@ where
let mut res = BytesMut::new();
res.put(row_description(&columns));
for (_, pool) in get_all_pools() {
let pool_config = &pool.settings;
for shard in 0..pool.shards() {
for server in 0..pool.servers(shard) {
let address = pool.address(shard, server);
let stats = match stats.get(&address.id) {
Some(stats) => stats.clone(),
None => HashMap::new(),
};
for shard in 0..pool.shards() {
for server in 0..pool.servers(shard) {
let address = pool.address(shard, server);
let stats = match stats.get(&address.id) {
Some(stats) => stats.clone(),
None => HashMap::new(),
};
let mut row = vec![address.name(), pool_config.user.username.clone()];
let mut row = vec![address.name(), config.user.name.clone()];
for column in &columns[2..columns.len() - 1] {
let value = stats.get(column.0).unwrap_or(&0).to_string();
row.push(value);
}
for column in &columns[2..columns.len() - 1] {
let value = stats.get(column.0).unwrap_or(&0).to_string();
row.push(value);
row.push(pool_config.pool_mode.to_string());
res.put(data_row(&row));
}
row.push(config.general.pool_mode.to_string());
res.put(data_row(&row));
}
}
@@ -208,12 +211,10 @@ where
}
/// Show shards and replicas.
async fn show_databases<T>(stream: &mut T, pool: &ConnectionPool) -> Result<(), Error>
async fn show_databases<T>(stream: &mut T) -> Result<(), Error>
where
T: tokio::io::AsyncWrite + std::marker::Unpin,
{
let config = get_config();
// Columns
let columns = vec![
("name", DataType::Text),
@@ -235,31 +236,33 @@ where
res.put(row_description(&columns));
for shard in 0..pool.shards() {
let database_name = &config.shards[&shard.to_string()].database;
for (_, pool) in get_all_pools() {
let pool_config = pool.settings.clone();
for shard in 0..pool.shards() {
let database_name = &pool_config.shards[&shard.to_string()].database;
for server in 0..pool.servers(shard) {
let address = pool.address(shard, server);
let pool_state = pool.pool_state(shard, server);
for server in 0..pool.servers(shard) {
let address = pool.address(shard, server);
let pool_state = pool.pool_state(shard, server);
res.put(data_row(&vec![
address.name(), // name
address.host.to_string(), // host
address.port.to_string(), // port
database_name.to_string(), // database
config.user.name.to_string(), // force_user
config.general.pool_size.to_string(), // pool_size
"0".to_string(), // min_pool_size
"0".to_string(), // reserve_pool
config.general.pool_mode.to_string(), // pool_mode
config.general.pool_size.to_string(), // max_connections
pool_state.connections.to_string(), // current_connections
"0".to_string(), // paused
"0".to_string(), // disabled
]));
res.put(data_row(&vec![
address.name(), // name
address.host.to_string(), // host
address.port.to_string(), // port
database_name.to_string(), // database
pool_config.user.username.to_string(), // force_user
pool_config.user.pool_size.to_string(), // pool_size
"0".to_string(), // min_pool_size
"0".to_string(), // reserve_pool
pool_config.pool_mode.to_string(), // pool_mode
pool_config.user.pool_size.to_string(), // max_connections
pool_state.connections.to_string(), // current_connections
"0".to_string(), // paused
"0".to_string(), // disabled
]));
}
}
}
res.put(command_complete("SHOW"));
// ReadyForQuery
@@ -349,7 +352,7 @@ where
}
/// Show shard and replicas statistics.
async fn show_stats<T>(stream: &mut T, pool: &ConnectionPool) -> Result<(), Error>
async fn show_stats<T>(stream: &mut T) -> Result<(), Error>
where
T: tokio::io::AsyncWrite + std::marker::Unpin,
{
@@ -375,21 +378,23 @@ where
let mut res = BytesMut::new();
res.put(row_description(&columns));
for shard in 0..pool.shards() {
for server in 0..pool.servers(shard) {
let address = pool.address(shard, server);
let stats = match stats.get(&address.id) {
Some(stats) => stats.clone(),
None => HashMap::new(),
};
for (_, pool) in get_all_pools() {
for shard in 0..pool.shards() {
for server in 0..pool.servers(shard) {
let address = pool.address(shard, server);
let stats = match stats.get(&address.id) {
Some(stats) => stats.clone(),
None => HashMap::new(),
};
let mut row = vec![address.name()];
let mut row = vec![address.name()];
for column in &columns[1..] {
row.push(stats.get(column.0).unwrap_or(&0).to_string());
for column in &columns[1..] {
row.push(stats.get(column.0).unwrap_or(&0).to_string());
}
res.put(data_row(&row));
}
res.put(data_row(&row));
}
}

View File

@@ -10,7 +10,7 @@ use crate::config::get_config;
use crate::constants::*;
use crate::errors::Error;
use crate::messages::*;
use crate::pool::{get_pool, ClientServerMap};
use crate::pool::{get_pool, ClientServerMap, ConnectionPool};
use crate::query_router::{Command, QueryRouter};
use crate::server::Server;
use crate::stats::{get_reporter, Reporter};
@@ -71,6 +71,8 @@ pub struct Client<S, T> {
/// Last server process id we talked to.
last_server_id: Option<i32>,
target_pool: ConnectionPool,
}
/// Client entrypoint.
@@ -258,11 +260,25 @@ where
client_server_map: ClientServerMap,
) -> Result<Client<S, T>, Error> {
let config = get_config();
let transaction_mode = config.general.pool_mode == "transaction";
let stats = get_reporter();
trace!("Got StartupMessage");
let parameters = parse_startup(bytes.clone())?;
let database = match parameters.get("database") {
Some(db) => db,
None => return Err(Error::ClientError),
};
let user = match parameters.get("user") {
Some(user) => user,
None => return Err(Error::ClientError),
};
let admin = ["pgcat", "pgbouncer"]
.iter()
.filter(|db| *db == &database)
.count()
== 1;
// Generate random backend ID and secret key
let process_id: i32 = rand::random();
@@ -295,33 +311,57 @@ where
Err(_) => return Err(Error::SocketError),
};
// Compare server and client hashes.
let password_hash = md5_hash_password(&config.user.name, &config.user.password, &salt);
let mut target_pool: ConnectionPool = ConnectionPool::default();
let mut transaction_mode = false;
if password_hash != password_response {
debug!("Password authentication failed");
wrong_password(&mut write, &config.user.name).await?;
return Err(Error::ClientError);
if admin {
let correct_user = config.general.admin_username.as_str();
let correct_password = config.general.admin_password.as_str();
// Compare server and client hashes.
let password_hash = md5_hash_password(correct_user, correct_password, &salt);
if password_hash != password_response {
debug!("Password authentication failed");
wrong_password(&mut write, user).await?;
return Err(Error::ClientError);
}
} else {
target_pool = match get_pool(database.clone(), user.clone()) {
Some(pool) => pool,
None => {
error_response(
&mut write,
&format!(
"No pool configured for database: {:?}, user: {:?}",
database, user
),
)
.await?;
return Err(Error::ClientError);
}
};
transaction_mode = target_pool.settings.pool_mode == "transaction";
// Compare server and client hashes.
let correct_password = target_pool.settings.user.password.as_str();
let password_hash = md5_hash_password(user, correct_password, &salt);
if password_hash != password_response {
debug!("Password authentication failed");
wrong_password(&mut write, user).await?;
return Err(Error::ClientError);
}
}
debug!("Password authentication successful");
auth_ok(&mut write).await?;
write_all(&mut write, get_pool().server_info()).await?;
write_all(&mut write, target_pool.server_info()).await?;
backend_key_data(&mut write, process_id, secret_key).await?;
ready_for_query(&mut write).await?;
trace!("Startup OK");
let database = parameters
.get("database")
.unwrap_or(parameters.get("user").unwrap());
let admin = ["pgcat", "pgbouncer"]
.iter()
.filter(|db| *db == &database)
.count()
== 1;
// Split the read and write streams
// so we can control buffering.
@@ -335,11 +375,12 @@ where
process_id: process_id,
secret_key: secret_key,
client_server_map: client_server_map,
parameters: parameters,
parameters: parameters.clone(),
stats: stats,
admin: admin,
last_address_id: None,
last_server_id: None,
target_pool: target_pool,
});
}
@@ -353,26 +394,22 @@ where
) -> Result<Client<S, T>, Error> {
let process_id = bytes.get_i32();
let secret_key = bytes.get_i32();
let config = get_config();
let transaction_mode = config.general.pool_mode == "transaction";
let stats = get_reporter();
return Ok(Client {
read: BufReader::new(read),
write: write,
addr,
buffer: BytesMut::with_capacity(8196),
cancel_mode: true,
transaction_mode: transaction_mode,
transaction_mode: false,
process_id: process_id,
secret_key: secret_key,
client_server_map: client_server_map,
parameters: HashMap::new(),
stats: stats,
stats: get_reporter(),
admin: false,
last_address_id: None,
last_server_id: None,
target_pool: ConnectionPool::default(),
});
}
@@ -410,7 +447,7 @@ where
// The query router determines where the query is going to go,
// e.g. primary, replica, which shard.
let mut query_router = QueryRouter::new();
let mut query_router = QueryRouter::new(self.target_pool.clone());
let mut round_robin = 0;
// Our custom protocol loop.
@@ -432,7 +469,7 @@ where
// Get a pool instance referenced by the most up-to-date
// pointer. This ensures we always read the latest config
// when starting a query.
let mut pool = get_pool();
let mut pool = self.target_pool.clone();
// Avoid taking a server if the client just wants to disconnect.
if message[0] as char == 'X' {
@@ -443,13 +480,7 @@ where
// Handle admin database queries.
if self.admin {
debug!("Handling admin command");
handle_admin(
&mut self.write,
message,
pool.clone(),
self.client_server_map.clone(),
)
.await?;
handle_admin(&mut self.write, message, self.client_server_map.clone()).await?;
continue;
}

View File

@@ -4,6 +4,7 @@ use log::{error, info};
use once_cell::sync::Lazy;
use serde_derive::Deserialize;
use std::collections::{HashMap, HashSet};
use std::hash::Hash;
use std::path::Path;
use std::sync::Arc;
use tokio::fs::File;
@@ -14,6 +15,8 @@ use crate::errors::Error;
use crate::tls::{load_certs, load_keys};
use crate::{ClientServerMap, ConnectionPool};
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
/// Globally available configuration.
static CONFIG: Lazy<ArcSwap<Config>> = Lazy::new(|| ArcSwap::from_pointee(Config::default()));
@@ -58,6 +61,7 @@ pub struct Address {
pub host: String,
pub port: String,
pub shard: usize,
pub database: String,
pub role: Role,
pub replica_number: usize,
}
@@ -70,6 +74,7 @@ impl Default for Address {
port: String::from("5432"),
shard: 0,
replica_number: 0,
database: String::from("database"),
role: Role::Replica,
}
}
@@ -79,9 +84,12 @@ impl Address {
/// Address name (aka database) used in `SHOW STATS`, `SHOW DATABASES`, and `SHOW POOLS`.
pub fn name(&self) -> String {
match self.role {
Role::Primary => format!("shard_{}_primary", self.shard),
Role::Primary => format!("{}_shard_{}_primary", self.database, self.shard),
Role::Replica => format!("shard_{}_replica_{}", self.shard, self.replica_number),
Role::Replica => format!(
"{}_shard_{}_replica_{}",
self.database, self.shard, self.replica_number
),
}
}
}
@@ -89,15 +97,17 @@ impl Address {
/// PostgreSQL user.
#[derive(Clone, PartialEq, Hash, std::cmp::Eq, Deserialize, Debug)]
pub struct User {
pub name: String,
pub username: String,
pub password: String,
pub pool_size: u32,
}
impl Default for User {
fn default() -> User {
User {
name: String::from("postgres"),
username: String::from("postgres"),
password: String::new(),
pool_size: 15,
}
}
}
@@ -107,14 +117,14 @@ impl Default for User {
pub struct General {
pub host: String,
pub port: i16,
pub pool_size: u32,
pub pool_mode: String,
pub connect_timeout: u64,
pub healthcheck_timeout: u64,
pub ban_time: i64,
pub autoreload: bool,
pub tls_certificate: Option<String>,
pub tls_private_key: Option<String>,
pub admin_username: String,
pub admin_password: String,
}
impl Default for General {
@@ -122,14 +132,37 @@ impl Default for General {
General {
host: String::from("localhost"),
port: 5432,
pool_size: 15,
pool_mode: String::from("transaction"),
connect_timeout: 5000,
healthcheck_timeout: 1000,
ban_time: 60,
autoreload: false,
tls_certificate: None,
tls_private_key: None,
admin_username: String::from("admin"),
admin_password: String::from("admin"),
}
}
}
#[derive(Deserialize, Debug, Clone, PartialEq)]
pub struct Pool {
pub pool_mode: String,
pub shards: HashMap<String, Shard>,
pub users: HashMap<String, User>,
pub default_role: String,
pub query_parser_enabled: bool,
pub primary_reads_enabled: bool,
pub sharding_function: String,
}
impl Default for Pool {
fn default() -> Pool {
Pool {
pool_mode: String::from("transaction"),
shards: HashMap::from([(String::from("1"), Shard::default())]),
users: HashMap::default(),
default_role: String::from("any"),
query_parser_enabled: false,
primary_reads_enabled: true,
sharding_function: "pg_bigint_hash".to_string(),
}
}
}
@@ -137,8 +170,8 @@ impl Default for General {
/// Shard configuration.
#[derive(Deserialize, Debug, Clone, PartialEq)]
pub struct Shard {
pub servers: Vec<(String, u16, String)>,
pub database: String,
pub servers: Vec<(String, u16, String)>,
}
impl Default for Shard {
@@ -150,26 +183,6 @@ impl Default for Shard {
}
}
/// Query Router configuration.
#[derive(Deserialize, Debug, Clone, PartialEq)]
pub struct QueryRouter {
pub default_role: String,
pub query_parser_enabled: bool,
pub primary_reads_enabled: bool,
pub sharding_function: String,
}
impl Default for QueryRouter {
fn default() -> QueryRouter {
QueryRouter {
default_role: String::from("any"),
query_parser_enabled: false,
primary_reads_enabled: true,
sharding_function: "pg_bigint_hash".to_string(),
}
}
}
fn default_path() -> String {
String::from("pgcat.toml")
}
@@ -181,9 +194,7 @@ pub struct Config {
pub path: String,
pub general: General,
pub user: User,
pub shards: HashMap<String, Shard>,
pub query_router: QueryRouter,
pub pools: HashMap<String, Pool>,
}
impl Default for Config {
@@ -191,26 +202,58 @@ impl Default for Config {
Config {
path: String::from("pgcat.toml"),
general: General::default(),
user: User::default(),
shards: HashMap::from([(String::from("1"), Shard::default())]),
query_router: QueryRouter::default(),
pools: HashMap::default(),
}
}
}
impl From<&Config> for std::collections::HashMap<String, String> {
fn from(config: &Config) -> HashMap<String, String> {
HashMap::from([
let mut r: Vec<(String, String)> = config
.pools
.iter()
.flat_map(|(pool_name, pool)| {
[
(
format!("pools.{}.pool_mode", pool_name),
pool.pool_mode.clone(),
),
(
format!("pools.{}.primary_reads_enabled", pool_name),
pool.primary_reads_enabled.to_string(),
),
(
format!("pools.{}.query_parser_enabled", pool_name),
pool.query_parser_enabled.to_string(),
),
(
format!("pools.{}.default_role", pool_name),
pool.default_role.clone(),
),
(
format!("pools.{}.sharding_function", pool_name),
pool.sharding_function.clone(),
),
(
format!("pools.{:?}.shard_count", pool_name),
pool.shards.len().to_string(),
),
(
format!("pools.{:?}.users", pool_name),
pool.users
.iter()
.map(|(_username, user)| &user.username)
.cloned()
.collect::<Vec<String>>()
.join(", "),
),
]
})
.collect();
let mut static_settings = vec![
("host".to_string(), config.general.host.to_string()),
("port".to_string(), config.general.port.to_string()),
(
"pool_size".to_string(),
config.general.pool_size.to_string(),
),
(
"pool_mode".to_string(),
config.general.pool_mode.to_string(),
),
(
"connect_timeout".to_string(),
config.general.connect_timeout.to_string(),
@@ -220,42 +263,22 @@ impl From<&Config> for std::collections::HashMap<String, String> {
config.general.healthcheck_timeout.to_string(),
),
("ban_time".to_string(), config.general.ban_time.to_string()),
(
"default_role".to_string(),
config.query_router.default_role.to_string(),
),
(
"query_parser_enabled".to_string(),
config.query_router.query_parser_enabled.to_string(),
),
(
"primary_reads_enabled".to_string(),
config.query_router.primary_reads_enabled.to_string(),
),
(
"sharding_function".to_string(),
config.query_router.sharding_function.to_string(),
),
])
];
r.append(&mut static_settings);
return r.iter().cloned().collect();
}
}
impl Config {
/// Print current configuration.
pub fn show(&self) {
info!("Pool size: {}", self.general.pool_size);
info!("Pool mode: {}", self.general.pool_mode);
info!("Ban time: {}s", self.general.ban_time);
info!(
"Healthcheck timeout: {}ms",
self.general.healthcheck_timeout
);
info!("Connection timeout: {}ms", self.general.connect_timeout);
info!("Sharding function: {}", self.query_router.sharding_function);
info!("Primary reads: {}", self.query_router.primary_reads_enabled);
info!("Query router: {}", self.query_router.query_parser_enabled);
info!("Number of shards: {}", self.shards.len());
match self.general.tls_certificate.clone() {
Some(tls_certificate) => {
info!("TLS certificate: {}", tls_certificate);
@@ -274,6 +297,25 @@ impl Config {
info!("TLS support is disabled");
}
};
for (pool_name, pool_config) in &self.pools {
info!("--- Settings for pool {} ---", pool_name);
info!(
"Pool size from all users: {}",
pool_config
.users
.iter()
.map(|(_, user_cfg)| user_cfg.pool_size)
.sum::<u32>()
.to_string()
);
info!("Pool mode: {}", pool_config.pool_mode);
info!("Sharding function: {}", pool_config.sharding_function);
info!("Primary reads: {}", pool_config.primary_reads_enabled);
info!("Query router: {}", pool_config.query_parser_enabled);
info!("Number of shards: {}", pool_config.shards.len());
info!("Number of users: {}", pool_config.users.len());
}
}
}
@@ -311,88 +353,6 @@ pub async fn parse(path: &str) -> Result<(), Error> {
}
};
match config.query_router.sharding_function.as_ref() {
"pg_bigint_hash" => (),
"sha1" => (),
_ => {
error!(
"Supported sharding functions are: 'pg_bigint_hash', 'sha1', got: '{}'",
config.query_router.sharding_function
);
return Err(Error::BadConfig);
}
};
// Quick config sanity check.
for shard in &config.shards {
// We use addresses as unique identifiers,
// let's make sure they are unique in the config as well.
let mut dup_check = HashSet::new();
let mut primary_count = 0;
match shard.0.parse::<usize>() {
Ok(_) => (),
Err(_) => {
error!(
"Shard '{}' is not a valid number, shards must be numbered starting at 0",
shard.0
);
return Err(Error::BadConfig);
}
};
if shard.1.servers.len() == 0 {
error!("Shard {} has no servers configured", shard.0);
return Err(Error::BadConfig);
}
for server in &shard.1.servers {
dup_check.insert(server);
// Check that we define only zero or one primary.
match server.2.as_ref() {
"primary" => primary_count += 1,
_ => (),
};
// Check role spelling.
match server.2.as_ref() {
"primary" => (),
"replica" => (),
_ => {
error!(
"Shard {} server role must be either 'primary' or 'replica', got: '{}'",
shard.0, server.2
);
return Err(Error::BadConfig);
}
};
}
if primary_count > 1 {
error!("Shard {} has more than on primary configured", &shard.0);
return Err(Error::BadConfig);
}
if dup_check.len() != shard.1.servers.len() {
error!("Shard {} contains duplicate server configs", &shard.0);
return Err(Error::BadConfig);
}
}
match config.query_router.default_role.as_ref() {
"any" => (),
"primary" => (),
"replica" => (),
other => {
error!(
"Query router default_role must be 'primary', 'replica', or 'any', got: '{}'",
other
);
return Err(Error::BadConfig);
}
};
// Validate TLS!
match config.general.tls_certificate.clone() {
Some(tls_certificate) => {
@@ -424,6 +384,90 @@ pub async fn parse(path: &str) -> Result<(), Error> {
None => (),
};
for (pool_name, pool) in &config.pools {
match pool.sharding_function.as_ref() {
"pg_bigint_hash" => (),
"sha1" => (),
_ => {
error!(
"Supported sharding functions are: 'pg_bigint_hash', 'sha1', got: '{}' in pool {} settings",
pool.sharding_function,
pool_name
);
return Err(Error::BadConfig);
}
};
match pool.default_role.as_ref() {
"any" => (),
"primary" => (),
"replica" => (),
other => {
error!(
"Query router default_role must be 'primary', 'replica', or 'any', got: '{}'",
other
);
return Err(Error::BadConfig);
}
};
for shard in &pool.shards {
// We use addresses as unique identifiers,
// let's make sure they are unique in the config as well.
let mut dup_check = HashSet::new();
let mut primary_count = 0;
match shard.0.parse::<usize>() {
Ok(_) => (),
Err(_) => {
error!(
"Shard '{}' is not a valid number, shards must be numbered starting at 0",
shard.0
);
return Err(Error::BadConfig);
}
};
if shard.1.servers.len() == 0 {
error!("Shard {} has no servers configured", shard.0);
return Err(Error::BadConfig);
}
for server in &shard.1.servers {
dup_check.insert(server);
// Check that we define only zero or one primary.
match server.2.as_ref() {
"primary" => primary_count += 1,
_ => (),
};
// Check role spelling.
match server.2.as_ref() {
"primary" => (),
"replica" => (),
_ => {
error!(
"Shard {} server role must be either 'primary' or 'replica', got: '{}'",
shard.0, server.2
);
return Err(Error::BadConfig);
}
};
}
if primary_count > 1 {
error!("Shard {} has more than on primary configured", &shard.0);
return Err(Error::BadConfig);
}
if dup_check.len() != shard.1.servers.len() {
error!("Shard {} contains duplicate server configs", &shard.0);
return Err(Error::BadConfig);
}
}
}
config.path = path.to_string();
// Update the configuration globally.
@@ -434,7 +478,6 @@ pub async fn parse(path: &str) -> Result<(), Error> {
pub async fn reload_config(client_server_map: ClientServerMap) -> Result<bool, Error> {
let old_config = get_config();
match parse(&old_config.path).await {
Ok(()) => (),
Err(err) => {
@@ -442,11 +485,10 @@ pub async fn reload_config(client_server_map: ClientServerMap) -> Result<bool, E
return Err(Error::BadConfig);
}
};
let new_config = get_config();
if old_config.shards != new_config.shards || old_config.user != new_config.user {
info!("Sharding configuration changed, re-creating server pools");
if old_config.pools != new_config.pools {
info!("Pool configuration changed, re-creating server pools");
ConnectionPool::from_config(client_server_map).await?;
Ok(true)
} else if old_config != new_config {
@@ -463,11 +505,58 @@ mod test {
#[tokio::test]
async fn test_config() {
parse("pgcat.toml").await.unwrap();
assert_eq!(get_config().general.pool_size, 15);
assert_eq!(get_config().shards.len(), 3);
assert_eq!(get_config().shards["1"].servers[0].0, "127.0.0.1");
assert_eq!(get_config().shards["0"].servers[0].2, "primary");
assert_eq!(get_config().query_router.default_role, "any");
assert_eq!(get_config().path, "pgcat.toml".to_string());
assert_eq!(get_config().general.ban_time, 60);
assert_eq!(get_config().pools.len(), 2);
assert_eq!(get_config().pools["sharded"].shards.len(), 3);
assert_eq!(get_config().pools["simple_db"].shards.len(), 1);
assert_eq!(get_config().pools["sharded"].users.len(), 2);
assert_eq!(get_config().pools["simple_db"].users.len(), 1);
assert_eq!(
get_config().pools["sharded"].shards["0"].servers[0].0,
"127.0.0.1"
);
assert_eq!(
get_config().pools["sharded"].shards["1"].servers[0].2,
"primary"
);
assert_eq!(get_config().pools["sharded"].shards["1"].database, "shard1");
assert_eq!(
get_config().pools["sharded"].users["0"].username,
"sharding_user"
);
assert_eq!(
get_config().pools["sharded"].users["1"].password,
"other_user"
);
assert_eq!(get_config().pools["sharded"].users["1"].pool_size, 21);
assert_eq!(get_config().pools["sharded"].default_role, "any");
assert_eq!(
get_config().pools["simple_db"].shards["0"].servers[0].0,
"127.0.0.1"
);
assert_eq!(
get_config().pools["simple_db"].shards["0"].servers[0].1,
5432
);
assert_eq!(
get_config().pools["simple_db"].shards["0"].database,
"some_db"
);
assert_eq!(get_config().pools["simple_db"].default_role, "primary");
assert_eq!(
get_config().pools["simple_db"].users["0"].username,
"simple_user"
);
assert_eq!(
get_config().pools["simple_db"].users["0"].password,
"simple_user"
);
assert_eq!(get_config().pools["simple_db"].users["0"].pool_size, 5);
}
}

View File

@@ -66,10 +66,12 @@ use config::{get_config, reload_config};
use pool::{ClientServerMap, ConnectionPool};
use stats::{Collector, Reporter, REPORTER};
use crate::config::VERSION;
#[tokio::main(worker_threads = 4)]
async fn main() {
env_logger::init();
info!("Welcome to PgCat! Meow.");
info!("Welcome to PgCat! Meow. (Version {})", VERSION);
if !query_router::QueryRouter::setup() {
error!("Could not setup query router");

View File

@@ -10,19 +10,43 @@ use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use crate::config::{get_config, Address, Role, User};
use crate::config::{get_config, Address, Role, Shard, User};
use crate::errors::Error;
use crate::server::Server;
use crate::stats::{get_reporter, Reporter};
pub type BanList = Arc<RwLock<Vec<HashMap<Address, NaiveDateTime>>>>;
pub type ClientServerMap = Arc<Mutex<HashMap<(i32, i32), (i32, i32, String, String)>>>;
pub type PoolMap = HashMap<(String, String), ConnectionPool>;
/// The connection pool, globally available.
/// This is atomic and safe and read-optimized.
/// The pool is recreated dynamically when the config is reloaded.
pub static POOL: Lazy<ArcSwap<ConnectionPool>> =
Lazy::new(|| ArcSwap::from_pointee(ConnectionPool::default()));
pub static POOLS: Lazy<ArcSwap<PoolMap>> = Lazy::new(|| ArcSwap::from_pointee(HashMap::default()));
#[derive(Clone, Debug)]
pub struct PoolSettings {
pub pool_mode: String,
pub shards: HashMap<String, Shard>,
pub user: User,
pub default_role: String,
pub query_parser_enabled: bool,
pub primary_reads_enabled: bool,
pub sharding_function: String,
}
impl Default for PoolSettings {
fn default() -> PoolSettings {
PoolSettings {
pool_mode: String::from("transaction"),
shards: HashMap::from([(String::from("1"), Shard::default())]),
user: User::default(),
default_role: String::from("any"),
query_parser_enabled: false,
primary_reads_enabled: true,
sharding_function: "pg_bigint_hash".to_string(),
}
}
}
/// The globally accessible connection pool.
#[derive(Clone, Debug, Default)]
@@ -46,107 +70,124 @@ pub struct ConnectionPool {
/// clients on startup. We pre-connect to all shards and replicas
/// on pool creation and save the K messages here.
server_info: BytesMut,
pub settings: PoolSettings,
}
impl ConnectionPool {
/// Construct the connection pool from the configuration.
pub async fn from_config(client_server_map: ClientServerMap) -> Result<(), Error> {
let reporter = get_reporter();
let config = get_config();
let mut new_pools = PoolMap::default();
let mut shards = Vec::new();
let mut addresses = Vec::new();
let mut banlist = Vec::new();
let mut address_id = 0;
let mut shard_ids = config
.shards
.clone()
.into_keys()
.map(|x| x.to_string())
.collect::<Vec<String>>();
for (pool_name, pool_config) in &config.pools {
for (_user_index, user_info) in &pool_config.users {
let mut shards = Vec::new();
let mut addresses = Vec::new();
let mut banlist = Vec::new();
let mut shard_ids = pool_config
.shards
.clone()
.into_keys()
.map(|x| x.to_string())
.collect::<Vec<String>>();
// Sort by shard number to ensure consistency.
shard_ids.sort_by_key(|k| k.parse::<i64>().unwrap());
// Sort by shard number to ensure consistency.
shard_ids.sort_by_key(|k| k.parse::<i64>().unwrap());
for shard_idx in shard_ids {
let shard = &config.shards[&shard_idx];
let mut pools = Vec::new();
let mut servers = Vec::new();
let mut replica_number = 0;
for shard_idx in shard_ids {
let shard = &pool_config.shards[&shard_idx];
let mut pools = Vec::new();
let mut servers = Vec::new();
let mut replica_number = 0;
for server in shard.servers.iter() {
let role = match server.2.as_ref() {
"primary" => Role::Primary,
"replica" => Role::Replica,
_ => {
error!("Config error: server role can be 'primary' or 'replica', have: '{}'. Defaulting to 'replica'.", server.2);
Role::Replica
for server in shard.servers.iter() {
let role = match server.2.as_ref() {
"primary" => Role::Primary,
"replica" => Role::Replica,
_ => {
error!("Config error: server role can be 'primary' or 'replica', have: '{}'. Defaulting to 'replica'.", server.2);
Role::Replica
}
};
let address = Address {
id: address_id,
database: pool_name.clone(),
host: server.0.clone(),
port: server.1.to_string(),
role: role,
replica_number,
shard: shard_idx.parse::<usize>().unwrap(),
};
address_id += 1;
if role == Role::Replica {
replica_number += 1;
}
let manager = ServerPool::new(
address.clone(),
user_info.clone(),
&shard.database,
client_server_map.clone(),
get_reporter(),
);
let pool = Pool::builder()
.max_size(user_info.pool_size)
.connection_timeout(std::time::Duration::from_millis(
config.general.connect_timeout,
))
.test_on_check_out(false)
.build(manager)
.await
.unwrap();
pools.push(pool);
servers.push(address);
}
};
let address = Address {
id: address_id,
host: server.0.clone(),
port: server.1.to_string(),
role: role,
replica_number,
shard: shard_idx.parse::<usize>().unwrap(),
};
address_id += 1;
if role == Role::Replica {
replica_number += 1;
shards.push(pools);
addresses.push(servers);
banlist.push(HashMap::new());
}
let manager = ServerPool::new(
address.clone(),
config.user.clone(),
&shard.database,
client_server_map.clone(),
reporter.clone(),
);
assert_eq!(shards.len(), addresses.len());
let pool = Pool::builder()
.max_size(config.general.pool_size)
.connection_timeout(std::time::Duration::from_millis(
config.general.connect_timeout,
))
.test_on_check_out(false)
.build(manager)
.await
.unwrap();
let mut pool = ConnectionPool {
databases: shards,
addresses: addresses,
banlist: Arc::new(RwLock::new(banlist)),
stats: get_reporter(),
server_info: BytesMut::new(),
settings: PoolSettings {
pool_mode: pool_config.pool_mode.clone(),
shards: pool_config.shards.clone(),
user: user_info.clone(),
default_role: pool_config.default_role.clone(),
query_parser_enabled: pool_config.query_parser_enabled.clone(),
primary_reads_enabled: pool_config.primary_reads_enabled,
sharding_function: pool_config.sharding_function.clone(),
},
};
pools.push(pool);
servers.push(address);
// Connect to the servers to make sure pool configuration is valid
// before setting it globally.
match pool.validate().await {
Ok(_) => (),
Err(err) => {
error!("Could not validate connection pool: {:?}", err);
return Err(err);
}
};
new_pools.insert((pool_name.clone(), user_info.username.clone()), pool);
}
shards.push(pools);
addresses.push(servers);
banlist.push(HashMap::new());
}
assert_eq!(shards.len(), addresses.len());
let mut pool = ConnectionPool {
databases: shards,
addresses: addresses,
banlist: Arc::new(RwLock::new(banlist)),
stats: reporter,
server_info: BytesMut::new(),
};
// Connect to the servers to make sure pool configuration is valid
// before setting it globally.
match pool.validate().await {
Ok(_) => (),
Err(err) => {
error!("Could not validate connection pool: {:?}", err);
return Err(err);
}
};
POOL.store(Arc::new(pool.clone()));
POOLS.store(Arc::new(new_pools.clone()));
Ok(())
}
@@ -474,7 +515,7 @@ impl ManageConnection for ServerPool {
info!(
"Creating a new connection to {:?} using user {:?}",
self.address.name(),
self.user.name
self.user.username
);
// Put a temporary process_id into the stats
@@ -517,6 +558,20 @@ impl ManageConnection for ServerPool {
}
/// Get the connection pool
pub fn get_pool() -> ConnectionPool {
(*(*POOL.load())).clone()
pub fn get_pool(db: String, user: String) -> Option<ConnectionPool> {
match get_all_pools().get(&(db, user)) {
Some(pool) => Some(pool.clone()),
None => None,
}
}
pub fn get_number_of_addresses() -> usize {
get_all_pools()
.iter()
.map(|(_, pool)| pool.databases())
.sum()
}
pub fn get_all_pools() -> HashMap<(String, String), ConnectionPool> {
return (*(*POOLS.load())).clone();
}

View File

@@ -8,7 +8,8 @@ use sqlparser::ast::Statement::{Query, StartTransaction};
use sqlparser::dialect::PostgreSqlDialect;
use sqlparser::parser::Parser;
use crate::config::{get_config, Role};
use crate::config::Role;
use crate::pool::{ConnectionPool, PoolSettings};
use crate::sharding::{Sharder, ShardingFunction};
/// Regexes used to parse custom commands.
@@ -53,6 +54,8 @@ pub struct QueryRouter {
/// Include the primary into the replica pool for reads.
primary_reads_enabled: bool,
pool_settings: PoolSettings,
}
impl QueryRouter {
@@ -88,14 +91,13 @@ impl QueryRouter {
}
/// Create a new instance of the query router. Each client gets its own.
pub fn new() -> QueryRouter {
let config = get_config();
pub fn new(target_pool: ConnectionPool) -> QueryRouter {
QueryRouter {
active_shard: None,
active_role: None,
query_parser_enabled: config.query_router.query_parser_enabled,
primary_reads_enabled: config.query_router.primary_reads_enabled,
query_parser_enabled: target_pool.settings.query_parser_enabled,
primary_reads_enabled: target_pool.settings.primary_reads_enabled,
pool_settings: target_pool.settings,
}
}
@@ -130,15 +132,13 @@ impl QueryRouter {
return None;
}
let config = get_config();
let sharding_function = match config.query_router.sharding_function.as_ref() {
let sharding_function = match self.pool_settings.sharding_function.as_ref() {
"pg_bigint_hash" => ShardingFunction::PgBigintHash,
"sha1" => ShardingFunction::Sha1,
_ => unreachable!(),
};
let default_server_role = match config.query_router.default_role.as_ref() {
let default_server_role = match self.pool_settings.default_role.as_ref() {
"any" => None,
"primary" => Some(Role::Primary),
"replica" => Some(Role::Replica),
@@ -196,7 +196,7 @@ impl QueryRouter {
match command {
Command::SetShardingKey => {
let sharder = Sharder::new(config.shards.len(), sharding_function);
let sharder = Sharder::new(self.pool_settings.shards.len(), sharding_function);
let shard = sharder.shard(value.parse::<i64>().unwrap());
self.active_shard = Some(shard);
value = shard.to_string();
@@ -204,7 +204,7 @@ impl QueryRouter {
Command::SetShard => {
self.active_shard = match value.to_ascii_uppercase().as_ref() {
"ANY" => Some(rand::random::<usize>() % config.shards.len()),
"ANY" => Some(rand::random::<usize>() % self.pool_settings.shards.len()),
_ => Some(value.parse::<usize>().unwrap()),
};
}
@@ -233,7 +233,7 @@ impl QueryRouter {
"default" => {
self.active_role = default_server_role;
self.query_parser_enabled = config.query_router.query_parser_enabled;
self.query_parser_enabled = self.query_parser_enabled;
self.active_role
}
@@ -250,7 +250,7 @@ impl QueryRouter {
self.primary_reads_enabled = false;
} else if value == "default" {
debug!("Setting primary reads to default");
self.primary_reads_enabled = config.query_router.primary_reads_enabled;
self.primary_reads_enabled = self.pool_settings.primary_reads_enabled;
}
}
@@ -370,7 +370,7 @@ mod test {
#[test]
fn test_defaults() {
QueryRouter::setup();
let qr = QueryRouter::new();
let qr = QueryRouter::new(ConnectionPool::default());
assert_eq!(qr.role(), None);
}
@@ -378,7 +378,7 @@ mod test {
#[test]
fn test_infer_role_replica() {
QueryRouter::setup();
let mut qr = QueryRouter::new();
let mut qr = QueryRouter::new(ConnectionPool::default());
assert!(qr.try_execute_command(simple_query("SET SERVER ROLE TO 'auto'")) != None);
assert_eq!(qr.query_parser_enabled(), true);
@@ -402,7 +402,7 @@ mod test {
#[test]
fn test_infer_role_primary() {
QueryRouter::setup();
let mut qr = QueryRouter::new();
let mut qr = QueryRouter::new(ConnectionPool::default());
let queries = vec![
simple_query("UPDATE items SET name = 'pumpkin' WHERE id = 5"),
@@ -421,7 +421,7 @@ mod test {
#[test]
fn test_infer_role_primary_reads_enabled() {
QueryRouter::setup();
let mut qr = QueryRouter::new();
let mut qr = QueryRouter::new(ConnectionPool::default());
let query = simple_query("SELECT * FROM items WHERE id = 5");
assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO on")) != None);
@@ -432,7 +432,7 @@ mod test {
#[test]
fn test_infer_role_parse_prepared() {
QueryRouter::setup();
let mut qr = QueryRouter::new();
let mut qr = QueryRouter::new(ConnectionPool::default());
qr.try_execute_command(simple_query("SET SERVER ROLE TO 'auto'"));
assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None);
@@ -523,15 +523,15 @@ mod test {
#[test]
fn test_try_execute_command() {
QueryRouter::setup();
let mut qr = QueryRouter::new();
let mut qr = QueryRouter::new(ConnectionPool::default());
// SetShardingKey
let query = simple_query("SET SHARDING KEY TO 13");
assert_eq!(
qr.try_execute_command(query),
Some((Command::SetShardingKey, String::from("1")))
Some((Command::SetShardingKey, String::from("0")))
);
assert_eq!(qr.shard(), 1);
assert_eq!(qr.shard(), 0);
// SetShard
let query = simple_query("SET SHARD TO '1'");
@@ -600,7 +600,7 @@ mod test {
#[test]
fn test_enable_query_parser() {
QueryRouter::setup();
let mut qr = QueryRouter::new();
let mut qr = QueryRouter::new(ConnectionPool::default());
let query = simple_query("SET SERVER ROLE TO 'auto'");
assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None);

View File

@@ -82,7 +82,7 @@ impl Server {
trace!("Sending StartupMessage");
// StartupMessage
startup(&mut stream, &user.name, database).await?;
startup(&mut stream, &user.username, database).await?;
let mut server_info = BytesMut::new();
let mut process_id: i32 = 0;
@@ -127,7 +127,7 @@ impl Server {
Err(_) => return Err(Error::SocketError),
};
md5_password(&mut stream, &user.name, &user.password, &salt[..])
md5_password(&mut stream, &user.username, &user.password, &salt[..])
.await?;
}

View File

@@ -6,7 +6,7 @@ use parking_lot::Mutex;
use std::collections::HashMap;
use tokio::sync::mpsc::{channel, Receiver, Sender};
use crate::pool::get_pool;
use crate::pool::get_number_of_addresses;
pub static REPORTER: Lazy<ArcSwap<Reporter>> =
Lazy::new(|| ArcSwap::from_pointee(Reporter::default()));
@@ -331,8 +331,8 @@ impl Collector {
tokio::time::interval(tokio::time::Duration::from_millis(STAT_PERIOD / 15));
loop {
interval.tick().await;
let addresses = get_pool().databases();
for address_id in 0..addresses {
let address_count = get_number_of_addresses();
for address_id in 0..address_count {
let _ = tx.try_send(Event {
name: EventName::UpdateStats,
value: 0,
@@ -349,8 +349,8 @@ impl Collector {
tokio::time::interval(tokio::time::Duration::from_millis(STAT_PERIOD));
loop {
interval.tick().await;
let addresses = get_pool().databases();
for address_id in 0..addresses {
let address_count = get_number_of_addresses();
for address_id in 0..address_count {
let _ = tx.try_send(Event {
name: EventName::UpdateAverages,
value: 0,

View File

@@ -15,7 +15,7 @@ ActiveRecord::Base.establish_connection(
port: 6432,
username: 'sharding_user',
password: 'sharding_user',
database: 'rails_dev',
database: 'sharded_db',
application_name: 'testing_pgcat',
prepared_statements: false, # Transaction mode
advisory_locks: false # Same
@@ -117,7 +117,7 @@ end
# Test evil clients
def poorly_behaved_client
conn = PG::connect("postgres://sharding_user:sharding_user@127.0.0.1:6432/rails_dev?application_name=testing_pgcat")
conn = PG::connect("postgres://sharding_user:sharding_user@127.0.0.1:6432/sharded_db?application_name=testing_pgcat")
conn.async_exec 'BEGIN'
conn.async_exec 'SELECT 1'