Add support for multi-database / multi-user pools (#96)

* Add support for multi-database / multi-user pools

* Nothing

* cargo fmt

* CI

* remove test users

* rename pool

* Update tests to use admin user/pass

* more fixes

* Revert bad change

* Use PGDATABASE env var

* send server info in case of admin
This commit is contained in:
Mostafa Abdelraouf
2022-07-27 21:47:55 -05:00
committed by GitHub
parent c5be5565a5
commit 2ae4b438e3
14 changed files with 700 additions and 503 deletions

View File

@@ -5,21 +5,12 @@
# #
# General pooler settings # General pooler settings
[general] [general]
# What IP to run on, 0.0.0.0 means accessible from everywhere. # What IP to run on, 0.0.0.0 means accessible from everywhere.
host = "0.0.0.0" host = "0.0.0.0"
# Port to run on, same as PgBouncer used in this example. # Port to run on, same as PgBouncer used in this example.
port = 6432 port = 6432
# How many connections to allocate per server.
pool_size = 15
# Pool mode (see PgBouncer docs for more).
# session: one server connection per connected client
# transaction: one server connection per client transaction
pool_mode = "transaction"
# How long to wait before aborting a server connection (ms). # How long to wait before aborting a server connection (ms).
connect_timeout = 100 connect_timeout = 100
@@ -29,56 +20,27 @@ healthcheck_timeout = 100
# For how long to ban a server if it fails a health check (seconds). # For how long to ban a server if it fails a health check (seconds).
ban_time = 60 # Seconds ban_time = 60 # Seconds
# # Reload config automatically if it changes.
autoreload = true autoreload = true
# TLS
tls_certificate = ".circleci/server.cert" tls_certificate = ".circleci/server.cert"
tls_private_key = ".circleci/server.key" tls_private_key = ".circleci/server.key"
# # Credentials to access the virtual administrative database (pgbouncer or pgcat)
# User to use for authentication against the server. # Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DATABASES`, etc..
[user] admin_username = "admin_user"
name = "sharding_user" admin_password = "admin_pass"
password = "sharding_user"
# pool
# # configs are structured as pools.<pool_name>
# Shards in the cluster # the pool_name is what clients use as database name when connecting
[shards] # For the example below a client can connect using "postgres://sharding_user:sharding_user@pgcat_host:pgcat_port/sharded_db"
[pools.sharded_db]
# Shard 0 # Pool mode (see PgBouncer docs for more).
[shards.0] # session: one server connection per connected client
# transaction: one server connection per client transaction
# [ host, port, role ] pool_mode = "transaction"
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5433, "replica" ],
# [ "127.0.1.1", 5432, "replica" ],
]
# Database name (e.g. "postgres")
database = "shard0"
[shards.1]
# [ host, port, role ]
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5433, "replica" ],
# [ "127.0.1.1", 5432, "replica" ],
]
database = "shard1"
[shards.2]
# [ host, port, role ]
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5433, "replica" ],
# [ "127.0.1.1", 5432, "replica" ],
]
database = "shard2"
# Settings for our query routing layer.
[query_router]
# If the client doesn't specify, route traffic to # If the client doesn't specify, route traffic to
# this role by default. # this role by default.
@@ -88,7 +50,6 @@ database = "shard2"
# primary: all queries go to the primary unless otherwise specified. # primary: all queries go to the primary unless otherwise specified.
default_role = "any" default_role = "any"
# Query parser. If enabled, we'll attempt to parse # Query parser. If enabled, we'll attempt to parse
# every incoming query to determine if it's a read or a write. # every incoming query to determine if it's a read or a write.
# If it's a read query, we'll direct it to a replica. Otherwise, if it's a write, # If it's a read query, we'll direct it to a replica. Otherwise, if it's a write,
@@ -109,3 +70,36 @@ primary_reads_enabled = true
# sha1: A hashing function based on SHA1 # sha1: A hashing function based on SHA1
# #
sharding_function = "pg_bigint_hash" sharding_function = "pg_bigint_hash"
# Credentials for users that may connect to this cluster
[pools.sharded_db.users.0]
username = "sharding_user"
password = "sharding_user"
# Maximum number of server connections that can be established for this user
# The maximum number of connections from a single Pgcat process to any database in the cluster
# is the sum of pool_size across all users.
pool_size = 9
# Shard 0
[pools.sharded_db.shards.0]
# [ host, port, role ]
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ]
]
# Database name (e.g. "postgres")
database = "shard0"
[pools.sharded_db.shards.1]
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ],
]
database = "shard1"
[pools.sharded_db.shards.2]
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ],
]
database = "shard2"

View File

@@ -32,6 +32,7 @@ toxiproxy-cli create -l 127.0.0.1:5433 -u 127.0.0.1:5432 postgres_replica
start_pgcat "info" start_pgcat "info"
export PGPASSWORD=sharding_user export PGPASSWORD=sharding_user
export PGDATABASE=sharded_db
# pgbench test # pgbench test
pgbench -U sharding_user -i -h 127.0.0.1 -p 6432 pgbench -U sharding_user -i -h 127.0.0.1 -p 6432
@@ -47,7 +48,7 @@ sleep 1
killall psql -s SIGINT killall psql -s SIGINT
# Reload pool (closing unused server connections) # Reload pool (closing unused server connections)
psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'RELOAD' PGPASSWORD=admin_pass psql -U admin_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'RELOAD'
(psql -U sharding_user -h 127.0.0.1 -p 6432 -c 'SELECT pg_sleep(50)' || true) & (psql -U sharding_user -h 127.0.0.1 -p 6432 -c 'SELECT pg_sleep(50)' || true) &
sleep 1 sleep 1
@@ -72,15 +73,17 @@ cd tests/ruby && \
cd ../.. cd ../..
# Admin tests # Admin tests
psql -U sharding_user -e -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW STATS' > /dev/null export PGPASSWORD=admin_pass
psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'RELOAD' > /dev/null psql -U admin_user -e -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW STATS' > /dev/null
psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW CONFIG' > /dev/null psql -U admin_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'RELOAD' > /dev/null
psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW DATABASES' > /dev/null psql -U admin_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW CONFIG' > /dev/null
psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW LISTS' > /dev/null psql -U admin_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW DATABASES' > /dev/null
psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW POOLS' > /dev/null psql -U admin_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW LISTS' > /dev/null
psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW VERSION' > /dev/null psql -U admin_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW POOLS' > /dev/null
psql -U sharding_user -h 127.0.0.1 -p 6432 -d pgbouncer -c "SET client_encoding TO 'utf8'" > /dev/null # will ignore psql -U admin_user -h 127.0.0.1 -p 6432 -d pgbouncer -c 'SHOW VERSION' > /dev/null
(! psql -U sharding_user -e -h 127.0.0.1 -p 6432 -d random_db -c 'SHOW STATS' > /dev/null) psql -U admin_user -h 127.0.0.1 -p 6432 -d pgbouncer -c "SET client_encoding TO 'utf8'" > /dev/null # will ignore
(! psql -U admin_user -e -h 127.0.0.1 -p 6432 -d random_db -c 'SHOW STATS' > /dev/null)
export PGPASSWORD=sharding_user
# Start PgCat in debug to demonstrate failover better # Start PgCat in debug to demonstrate failover better
start_pgcat "trace" start_pgcat "trace"

2
Cargo.lock generated
View File

@@ -395,7 +395,7 @@ dependencies = [
[[package]] [[package]]
name = "pgcat" name = "pgcat"
version = "0.4.0-beta1" version = "0.6.0-alpha1"
dependencies = [ dependencies = [
"arc-swap", "arc-swap",
"async-trait", "async-trait",

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "pgcat" name = "pgcat"
version = "0.4.0-beta1" version = "0.6.0-alpha1"
edition = "2021" edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

View File

@@ -5,21 +5,12 @@
# #
# General pooler settings # General pooler settings
[general] [general]
# What IP to run on, 0.0.0.0 means accessible from everywhere. # What IP to run on, 0.0.0.0 means accessible from everywhere.
host = "0.0.0.0" host = "0.0.0.0"
# Port to run on, same as PgBouncer used in this example. # Port to run on, same as PgBouncer used in this example.
port = 6432 port = 6432
# How many connections to allocate per server.
pool_size = 15
# Pool mode (see PgBouncer docs for more).
# session: one server connection per connected client
# transaction: one server connection per client transaction
pool_mode = "transaction"
# How long to wait before aborting a server connection (ms). # How long to wait before aborting a server connection (ms).
connect_timeout = 5000 connect_timeout = 5000
@@ -27,7 +18,7 @@ connect_timeout = 5000
healthcheck_timeout = 1000 healthcheck_timeout = 1000
# For how long to ban a server if it fails a health check (seconds). # For how long to ban a server if it fails a health check (seconds).
ban_time = 60 # Seconds ban_time = 60 # seconds
# Reload config automatically if it changes. # Reload config automatically if it changes.
autoreload = false autoreload = false
@@ -36,50 +27,20 @@ autoreload = false
# tls_certificate = "server.cert" # tls_certificate = "server.cert"
# tls_private_key = "server.key" # tls_private_key = "server.key"
# # Credentials to access the virtual administrative database (pgbouncer or pgcat)
# User to use for authentication against the server. # Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DATABASES`, etc..
[user] admin_username = "user"
name = "sharding_user" admin_password = "pass"
password = "sharding_user"
# pool
# # configs are structured as pools.<pool_name>
# Shards in the cluster # the pool_name is what clients use as database name when connecting
[shards] # For the example below a client can connect using "postgres://sharding_user:sharding_user@pgcat_host:pgcat_port/sharded"
[pools.sharded]
# Shard 0 # Pool mode (see PgBouncer docs for more).
[shards.0] # session: one server connection per connected client
# transaction: one server connection per client transaction
# [ host, port, role ] pool_mode = "transaction"
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ],
# [ "127.0.1.1", 5432, "replica" ],
]
# Database name (e.g. "postgres")
database = "shard0"
[shards.1]
# [ host, port, role ]
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ],
# [ "127.0.1.1", 5432, "replica" ],
]
database = "shard1"
[shards.2]
# [ host, port, role ]
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ],
# [ "127.0.1.1", 5432, "replica" ],
]
database = "shard2"
# Settings for our query routing layer.
[query_router]
# If the client doesn't specify, route traffic to # If the client doesn't specify, route traffic to
# this role by default. # this role by default.
@@ -89,7 +50,6 @@ database = "shard2"
# primary: all queries go to the primary unless otherwise specified. # primary: all queries go to the primary unless otherwise specified.
default_role = "any" default_role = "any"
# Query parser. If enabled, we'll attempt to parse # Query parser. If enabled, we'll attempt to parse
# every incoming query to determine if it's a read or a write. # every incoming query to determine if it's a read or a write.
# If it's a read query, we'll direct it to a replica. Otherwise, if it's a write, # If it's a read query, we'll direct it to a replica. Otherwise, if it's a write,
@@ -110,3 +70,61 @@ primary_reads_enabled = true
# sha1: A hashing function based on SHA1 # sha1: A hashing function based on SHA1
# #
sharding_function = "pg_bigint_hash" sharding_function = "pg_bigint_hash"
# Credentials for users that may connect to this cluster
[pools.sharded.users.0]
username = "sharding_user"
password = "sharding_user"
# Maximum number of server connections that can be established for this user
# The maximum number of connections from a single Pgcat process to any database in the cluster
# is the sum of pool_size across all users.
pool_size = 9
[pools.sharded.users.1]
username = "other_user"
password = "other_user"
pool_size = 21
# Shard 0
[pools.sharded.shards.0]
# [ host, port, role ]
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ]
]
# Database name (e.g. "postgres")
database = "shard0"
[pools.sharded.shards.1]
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ],
]
database = "shard1"
[pools.sharded.shards.2]
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ],
]
database = "shard2"
[pools.simple_db]
pool_mode = "session"
default_role = "primary"
query_parser_enabled = true
primary_reads_enabled = true
sharding_function = "pg_bigint_hash"
[pools.simple_db.users.0]
username = "simple_user"
password = "simple_user"
pool_size = 5
[pools.simple_db.shards.0]
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ]
]
database = "some_db"

View File

@@ -3,10 +3,10 @@ use bytes::{Buf, BufMut, BytesMut};
use log::{info, trace}; use log::{info, trace};
use std::collections::HashMap; use std::collections::HashMap;
use crate::config::{get_config, reload_config}; use crate::config::{get_config, reload_config, VERSION};
use crate::errors::Error; use crate::errors::Error;
use crate::messages::*; use crate::messages::*;
use crate::pool::ConnectionPool; use crate::pool::get_all_pools;
use crate::stats::get_stats; use crate::stats::get_stats;
use crate::ClientServerMap; use crate::ClientServerMap;
@@ -14,7 +14,6 @@ use crate::ClientServerMap;
pub async fn handle_admin<T>( pub async fn handle_admin<T>(
stream: &mut T, stream: &mut T,
mut query: BytesMut, mut query: BytesMut,
pool: ConnectionPool,
client_server_map: ClientServerMap, client_server_map: ClientServerMap,
) -> Result<(), Error> ) -> Result<(), Error>
where where
@@ -35,7 +34,7 @@ where
if query.starts_with("SHOW STATS") { if query.starts_with("SHOW STATS") {
trace!("SHOW STATS"); trace!("SHOW STATS");
show_stats(stream, &pool).await show_stats(stream).await
} else if query.starts_with("RELOAD") { } else if query.starts_with("RELOAD") {
trace!("RELOAD"); trace!("RELOAD");
reload(stream, client_server_map).await reload(stream, client_server_map).await
@@ -44,13 +43,13 @@ where
show_config(stream).await show_config(stream).await
} else if query.starts_with("SHOW DATABASES") { } else if query.starts_with("SHOW DATABASES") {
trace!("SHOW DATABASES"); trace!("SHOW DATABASES");
show_databases(stream, &pool).await show_databases(stream).await
} else if query.starts_with("SHOW POOLS") { } else if query.starts_with("SHOW POOLS") {
trace!("SHOW POOLS"); trace!("SHOW POOLS");
show_pools(stream, &pool).await show_pools(stream).await
} else if query.starts_with("SHOW LISTS") { } else if query.starts_with("SHOW LISTS") {
trace!("SHOW LISTS"); trace!("SHOW LISTS");
show_lists(stream, &pool).await show_lists(stream).await
} else if query.starts_with("SHOW VERSION") { } else if query.starts_with("SHOW VERSION") {
trace!("SHOW VERSION"); trace!("SHOW VERSION");
show_version(stream).await show_version(stream).await
@@ -63,7 +62,7 @@ where
} }
/// Column-oriented statistics. /// Column-oriented statistics.
async fn show_lists<T>(stream: &mut T, pool: &ConnectionPool) -> Result<(), Error> async fn show_lists<T>(stream: &mut T) -> Result<(), Error>
where where
T: tokio::io::AsyncWrite + std::marker::Unpin, T: tokio::io::AsyncWrite + std::marker::Unpin,
{ {
@@ -71,17 +70,20 @@ where
let columns = vec![("list", DataType::Text), ("items", DataType::Int4)]; let columns = vec![("list", DataType::Text), ("items", DataType::Int4)];
let mut users = 1;
let mut databases = 1;
for (_, pool) in get_all_pools() {
databases += pool.databases();
users += 1; // One user per pool
}
let mut res = BytesMut::new(); let mut res = BytesMut::new();
res.put(row_description(&columns)); res.put(row_description(&columns));
res.put(data_row(&vec![ res.put(data_row(&vec![
"databases".to_string(), "databases".to_string(),
(pool.databases() + 1).to_string(), // see comment below databases.to_string(),
])); ]));
res.put(data_row(&vec!["users".to_string(), "1".to_string()])); res.put(data_row(&vec!["users".to_string(), users.to_string()]));
res.put(data_row(&vec![ res.put(data_row(&vec!["pools".to_string(), databases.to_string()]));
"pools".to_string(),
(pool.databases() + 1).to_string(), // +1 for the pgbouncer admin db pool which isn't real
])); // but admin tools that work with pgbouncer want this
res.put(data_row(&vec![ res.put(data_row(&vec![
"free_clients".to_string(), "free_clients".to_string(),
stats stats
@@ -140,7 +142,7 @@ where
let mut res = BytesMut::new(); let mut res = BytesMut::new();
res.put(row_description(&vec![("version", DataType::Text)])); res.put(row_description(&vec![("version", DataType::Text)]));
res.put(data_row(&vec!["PgCat 0.1.0".to_string()])); res.put(data_row(&vec![format!("PgCat {}", VERSION).to_string()]));
res.put(command_complete("SHOW")); res.put(command_complete("SHOW"));
res.put_u8(b'Z'); res.put_u8(b'Z');
@@ -151,12 +153,11 @@ where
} }
/// Show utilization of connection pools for each shard and replicas. /// Show utilization of connection pools for each shard and replicas.
async fn show_pools<T>(stream: &mut T, pool: &ConnectionPool) -> Result<(), Error> async fn show_pools<T>(stream: &mut T) -> Result<(), Error>
where where
T: tokio::io::AsyncWrite + std::marker::Unpin, T: tokio::io::AsyncWrite + std::marker::Unpin,
{ {
let stats = get_stats(); let stats = get_stats();
let config = get_config();
let columns = vec![ let columns = vec![
("database", DataType::Text), ("database", DataType::Text),
@@ -176,24 +177,26 @@ where
let mut res = BytesMut::new(); let mut res = BytesMut::new();
res.put(row_description(&columns)); res.put(row_description(&columns));
for (_, pool) in get_all_pools() {
let pool_config = &pool.settings;
for shard in 0..pool.shards() {
for server in 0..pool.servers(shard) {
let address = pool.address(shard, server);
let stats = match stats.get(&address.id) {
Some(stats) => stats.clone(),
None => HashMap::new(),
};
for shard in 0..pool.shards() { let mut row = vec![address.name(), pool_config.user.username.clone()];
for server in 0..pool.servers(shard) {
let address = pool.address(shard, server);
let stats = match stats.get(&address.id) {
Some(stats) => stats.clone(),
None => HashMap::new(),
};
let mut row = vec![address.name(), config.user.name.clone()]; for column in &columns[2..columns.len() - 1] {
let value = stats.get(column.0).unwrap_or(&0).to_string();
row.push(value);
}
for column in &columns[2..columns.len() - 1] { row.push(pool_config.pool_mode.to_string());
let value = stats.get(column.0).unwrap_or(&0).to_string(); res.put(data_row(&row));
row.push(value);
} }
row.push(config.general.pool_mode.to_string());
res.put(data_row(&row));
} }
} }
@@ -208,12 +211,10 @@ where
} }
/// Show shards and replicas. /// Show shards and replicas.
async fn show_databases<T>(stream: &mut T, pool: &ConnectionPool) -> Result<(), Error> async fn show_databases<T>(stream: &mut T) -> Result<(), Error>
where where
T: tokio::io::AsyncWrite + std::marker::Unpin, T: tokio::io::AsyncWrite + std::marker::Unpin,
{ {
let config = get_config();
// Columns // Columns
let columns = vec![ let columns = vec![
("name", DataType::Text), ("name", DataType::Text),
@@ -235,31 +236,33 @@ where
res.put(row_description(&columns)); res.put(row_description(&columns));
for shard in 0..pool.shards() { for (_, pool) in get_all_pools() {
let database_name = &config.shards[&shard.to_string()].database; let pool_config = pool.settings.clone();
for shard in 0..pool.shards() {
let database_name = &pool_config.shards[&shard.to_string()].database;
for server in 0..pool.servers(shard) { for server in 0..pool.servers(shard) {
let address = pool.address(shard, server); let address = pool.address(shard, server);
let pool_state = pool.pool_state(shard, server); let pool_state = pool.pool_state(shard, server);
res.put(data_row(&vec![ res.put(data_row(&vec![
address.name(), // name address.name(), // name
address.host.to_string(), // host address.host.to_string(), // host
address.port.to_string(), // port address.port.to_string(), // port
database_name.to_string(), // database database_name.to_string(), // database
config.user.name.to_string(), // force_user pool_config.user.username.to_string(), // force_user
config.general.pool_size.to_string(), // pool_size pool_config.user.pool_size.to_string(), // pool_size
"0".to_string(), // min_pool_size "0".to_string(), // min_pool_size
"0".to_string(), // reserve_pool "0".to_string(), // reserve_pool
config.general.pool_mode.to_string(), // pool_mode pool_config.pool_mode.to_string(), // pool_mode
config.general.pool_size.to_string(), // max_connections pool_config.user.pool_size.to_string(), // max_connections
pool_state.connections.to_string(), // current_connections pool_state.connections.to_string(), // current_connections
"0".to_string(), // paused "0".to_string(), // paused
"0".to_string(), // disabled "0".to_string(), // disabled
])); ]));
}
} }
} }
res.put(command_complete("SHOW")); res.put(command_complete("SHOW"));
// ReadyForQuery // ReadyForQuery
@@ -349,7 +352,7 @@ where
} }
/// Show shard and replicas statistics. /// Show shard and replicas statistics.
async fn show_stats<T>(stream: &mut T, pool: &ConnectionPool) -> Result<(), Error> async fn show_stats<T>(stream: &mut T) -> Result<(), Error>
where where
T: tokio::io::AsyncWrite + std::marker::Unpin, T: tokio::io::AsyncWrite + std::marker::Unpin,
{ {
@@ -375,21 +378,23 @@ where
let mut res = BytesMut::new(); let mut res = BytesMut::new();
res.put(row_description(&columns)); res.put(row_description(&columns));
for shard in 0..pool.shards() { for (_, pool) in get_all_pools() {
for server in 0..pool.servers(shard) { for shard in 0..pool.shards() {
let address = pool.address(shard, server); for server in 0..pool.servers(shard) {
let stats = match stats.get(&address.id) { let address = pool.address(shard, server);
Some(stats) => stats.clone(), let stats = match stats.get(&address.id) {
None => HashMap::new(), Some(stats) => stats.clone(),
}; None => HashMap::new(),
};
let mut row = vec![address.name()]; let mut row = vec![address.name()];
for column in &columns[1..] { for column in &columns[1..] {
row.push(stats.get(column.0).unwrap_or(&0).to_string()); row.push(stats.get(column.0).unwrap_or(&0).to_string());
}
res.put(data_row(&row));
} }
res.put(data_row(&row));
} }
} }

View File

@@ -10,7 +10,7 @@ use crate::config::get_config;
use crate::constants::*; use crate::constants::*;
use crate::errors::Error; use crate::errors::Error;
use crate::messages::*; use crate::messages::*;
use crate::pool::{get_pool, ClientServerMap}; use crate::pool::{get_pool, ClientServerMap, ConnectionPool};
use crate::query_router::{Command, QueryRouter}; use crate::query_router::{Command, QueryRouter};
use crate::server::Server; use crate::server::Server;
use crate::stats::{get_reporter, Reporter}; use crate::stats::{get_reporter, Reporter};
@@ -71,6 +71,8 @@ pub struct Client<S, T> {
/// Last server process id we talked to. /// Last server process id we talked to.
last_server_id: Option<i32>, last_server_id: Option<i32>,
target_pool: ConnectionPool,
} }
/// Client entrypoint. /// Client entrypoint.
@@ -258,11 +260,25 @@ where
client_server_map: ClientServerMap, client_server_map: ClientServerMap,
) -> Result<Client<S, T>, Error> { ) -> Result<Client<S, T>, Error> {
let config = get_config(); let config = get_config();
let transaction_mode = config.general.pool_mode == "transaction";
let stats = get_reporter(); let stats = get_reporter();
trace!("Got StartupMessage"); trace!("Got StartupMessage");
let parameters = parse_startup(bytes.clone())?; let parameters = parse_startup(bytes.clone())?;
let database = match parameters.get("database") {
Some(db) => db,
None => return Err(Error::ClientError),
};
let user = match parameters.get("user") {
Some(user) => user,
None => return Err(Error::ClientError),
};
let admin = ["pgcat", "pgbouncer"]
.iter()
.filter(|db| *db == &database)
.count()
== 1;
// Generate random backend ID and secret key // Generate random backend ID and secret key
let process_id: i32 = rand::random(); let process_id: i32 = rand::random();
@@ -295,33 +311,57 @@ where
Err(_) => return Err(Error::SocketError), Err(_) => return Err(Error::SocketError),
}; };
// Compare server and client hashes. let mut target_pool: ConnectionPool = ConnectionPool::default();
let password_hash = md5_hash_password(&config.user.name, &config.user.password, &salt); let mut transaction_mode = false;
if password_hash != password_response { if admin {
debug!("Password authentication failed"); let correct_user = config.general.admin_username.as_str();
wrong_password(&mut write, &config.user.name).await?; let correct_password = config.general.admin_password.as_str();
return Err(Error::ClientError);
// Compare server and client hashes.
let password_hash = md5_hash_password(correct_user, correct_password, &salt);
if password_hash != password_response {
debug!("Password authentication failed");
wrong_password(&mut write, user).await?;
return Err(Error::ClientError);
}
} else {
target_pool = match get_pool(database.clone(), user.clone()) {
Some(pool) => pool,
None => {
error_response(
&mut write,
&format!(
"No pool configured for database: {:?}, user: {:?}",
database, user
),
)
.await?;
return Err(Error::ClientError);
}
};
transaction_mode = target_pool.settings.pool_mode == "transaction";
// Compare server and client hashes.
let correct_password = target_pool.settings.user.password.as_str();
let password_hash = md5_hash_password(user, correct_password, &salt);
if password_hash != password_response {
debug!("Password authentication failed");
wrong_password(&mut write, user).await?;
return Err(Error::ClientError);
}
} }
debug!("Password authentication successful"); debug!("Password authentication successful");
auth_ok(&mut write).await?; auth_ok(&mut write).await?;
write_all(&mut write, get_pool().server_info()).await?; write_all(&mut write, target_pool.server_info()).await?;
backend_key_data(&mut write, process_id, secret_key).await?; backend_key_data(&mut write, process_id, secret_key).await?;
ready_for_query(&mut write).await?; ready_for_query(&mut write).await?;
trace!("Startup OK"); trace!("Startup OK");
let database = parameters
.get("database")
.unwrap_or(parameters.get("user").unwrap());
let admin = ["pgcat", "pgbouncer"]
.iter()
.filter(|db| *db == &database)
.count()
== 1;
// Split the read and write streams // Split the read and write streams
// so we can control buffering. // so we can control buffering.
@@ -335,11 +375,12 @@ where
process_id: process_id, process_id: process_id,
secret_key: secret_key, secret_key: secret_key,
client_server_map: client_server_map, client_server_map: client_server_map,
parameters: parameters, parameters: parameters.clone(),
stats: stats, stats: stats,
admin: admin, admin: admin,
last_address_id: None, last_address_id: None,
last_server_id: None, last_server_id: None,
target_pool: target_pool,
}); });
} }
@@ -353,26 +394,22 @@ where
) -> Result<Client<S, T>, Error> { ) -> Result<Client<S, T>, Error> {
let process_id = bytes.get_i32(); let process_id = bytes.get_i32();
let secret_key = bytes.get_i32(); let secret_key = bytes.get_i32();
let config = get_config();
let transaction_mode = config.general.pool_mode == "transaction";
let stats = get_reporter();
return Ok(Client { return Ok(Client {
read: BufReader::new(read), read: BufReader::new(read),
write: write, write: write,
addr, addr,
buffer: BytesMut::with_capacity(8196), buffer: BytesMut::with_capacity(8196),
cancel_mode: true, cancel_mode: true,
transaction_mode: transaction_mode, transaction_mode: false,
process_id: process_id, process_id: process_id,
secret_key: secret_key, secret_key: secret_key,
client_server_map: client_server_map, client_server_map: client_server_map,
parameters: HashMap::new(), parameters: HashMap::new(),
stats: stats, stats: get_reporter(),
admin: false, admin: false,
last_address_id: None, last_address_id: None,
last_server_id: None, last_server_id: None,
target_pool: ConnectionPool::default(),
}); });
} }
@@ -410,7 +447,7 @@ where
// The query router determines where the query is going to go, // The query router determines where the query is going to go,
// e.g. primary, replica, which shard. // e.g. primary, replica, which shard.
let mut query_router = QueryRouter::new(); let mut query_router = QueryRouter::new(self.target_pool.clone());
let mut round_robin = 0; let mut round_robin = 0;
// Our custom protocol loop. // Our custom protocol loop.
@@ -432,7 +469,7 @@ where
// Get a pool instance referenced by the most up-to-date // Get a pool instance referenced by the most up-to-date
// pointer. This ensures we always read the latest config // pointer. This ensures we always read the latest config
// when starting a query. // when starting a query.
let mut pool = get_pool(); let mut pool = self.target_pool.clone();
// Avoid taking a server if the client just wants to disconnect. // Avoid taking a server if the client just wants to disconnect.
if message[0] as char == 'X' { if message[0] as char == 'X' {
@@ -443,13 +480,7 @@ where
// Handle admin database queries. // Handle admin database queries.
if self.admin { if self.admin {
debug!("Handling admin command"); debug!("Handling admin command");
handle_admin( handle_admin(&mut self.write, message, self.client_server_map.clone()).await?;
&mut self.write,
message,
pool.clone(),
self.client_server_map.clone(),
)
.await?;
continue; continue;
} }

View File

@@ -4,6 +4,7 @@ use log::{error, info};
use once_cell::sync::Lazy; use once_cell::sync::Lazy;
use serde_derive::Deserialize; use serde_derive::Deserialize;
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::hash::Hash;
use std::path::Path; use std::path::Path;
use std::sync::Arc; use std::sync::Arc;
use tokio::fs::File; use tokio::fs::File;
@@ -14,6 +15,8 @@ use crate::errors::Error;
use crate::tls::{load_certs, load_keys}; use crate::tls::{load_certs, load_keys};
use crate::{ClientServerMap, ConnectionPool}; use crate::{ClientServerMap, ConnectionPool};
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
/// Globally available configuration. /// Globally available configuration.
static CONFIG: Lazy<ArcSwap<Config>> = Lazy::new(|| ArcSwap::from_pointee(Config::default())); static CONFIG: Lazy<ArcSwap<Config>> = Lazy::new(|| ArcSwap::from_pointee(Config::default()));
@@ -58,6 +61,7 @@ pub struct Address {
pub host: String, pub host: String,
pub port: String, pub port: String,
pub shard: usize, pub shard: usize,
pub database: String,
pub role: Role, pub role: Role,
pub replica_number: usize, pub replica_number: usize,
} }
@@ -70,6 +74,7 @@ impl Default for Address {
port: String::from("5432"), port: String::from("5432"),
shard: 0, shard: 0,
replica_number: 0, replica_number: 0,
database: String::from("database"),
role: Role::Replica, role: Role::Replica,
} }
} }
@@ -79,9 +84,12 @@ impl Address {
/// Address name (aka database) used in `SHOW STATS`, `SHOW DATABASES`, and `SHOW POOLS`. /// Address name (aka database) used in `SHOW STATS`, `SHOW DATABASES`, and `SHOW POOLS`.
pub fn name(&self) -> String { pub fn name(&self) -> String {
match self.role { match self.role {
Role::Primary => format!("shard_{}_primary", self.shard), Role::Primary => format!("{}_shard_{}_primary", self.database, self.shard),
Role::Replica => format!("shard_{}_replica_{}", self.shard, self.replica_number), Role::Replica => format!(
"{}_shard_{}_replica_{}",
self.database, self.shard, self.replica_number
),
} }
} }
} }
@@ -89,15 +97,17 @@ impl Address {
/// PostgreSQL user. /// PostgreSQL user.
#[derive(Clone, PartialEq, Hash, std::cmp::Eq, Deserialize, Debug)] #[derive(Clone, PartialEq, Hash, std::cmp::Eq, Deserialize, Debug)]
pub struct User { pub struct User {
pub name: String, pub username: String,
pub password: String, pub password: String,
pub pool_size: u32,
} }
impl Default for User { impl Default for User {
fn default() -> User { fn default() -> User {
User { User {
name: String::from("postgres"), username: String::from("postgres"),
password: String::new(), password: String::new(),
pool_size: 15,
} }
} }
} }
@@ -107,14 +117,14 @@ impl Default for User {
pub struct General { pub struct General {
pub host: String, pub host: String,
pub port: i16, pub port: i16,
pub pool_size: u32,
pub pool_mode: String,
pub connect_timeout: u64, pub connect_timeout: u64,
pub healthcheck_timeout: u64, pub healthcheck_timeout: u64,
pub ban_time: i64, pub ban_time: i64,
pub autoreload: bool, pub autoreload: bool,
pub tls_certificate: Option<String>, pub tls_certificate: Option<String>,
pub tls_private_key: Option<String>, pub tls_private_key: Option<String>,
pub admin_username: String,
pub admin_password: String,
} }
impl Default for General { impl Default for General {
@@ -122,14 +132,37 @@ impl Default for General {
General { General {
host: String::from("localhost"), host: String::from("localhost"),
port: 5432, port: 5432,
pool_size: 15,
pool_mode: String::from("transaction"),
connect_timeout: 5000, connect_timeout: 5000,
healthcheck_timeout: 1000, healthcheck_timeout: 1000,
ban_time: 60, ban_time: 60,
autoreload: false, autoreload: false,
tls_certificate: None, tls_certificate: None,
tls_private_key: None, tls_private_key: None,
admin_username: String::from("admin"),
admin_password: String::from("admin"),
}
}
}
#[derive(Deserialize, Debug, Clone, PartialEq)]
pub struct Pool {
pub pool_mode: String,
pub shards: HashMap<String, Shard>,
pub users: HashMap<String, User>,
pub default_role: String,
pub query_parser_enabled: bool,
pub primary_reads_enabled: bool,
pub sharding_function: String,
}
impl Default for Pool {
fn default() -> Pool {
Pool {
pool_mode: String::from("transaction"),
shards: HashMap::from([(String::from("1"), Shard::default())]),
users: HashMap::default(),
default_role: String::from("any"),
query_parser_enabled: false,
primary_reads_enabled: true,
sharding_function: "pg_bigint_hash".to_string(),
} }
} }
} }
@@ -137,8 +170,8 @@ impl Default for General {
/// Shard configuration. /// Shard configuration.
#[derive(Deserialize, Debug, Clone, PartialEq)] #[derive(Deserialize, Debug, Clone, PartialEq)]
pub struct Shard { pub struct Shard {
pub servers: Vec<(String, u16, String)>,
pub database: String, pub database: String,
pub servers: Vec<(String, u16, String)>,
} }
impl Default for Shard { impl Default for Shard {
@@ -150,26 +183,6 @@ impl Default for Shard {
} }
} }
/// Query Router configuration.
#[derive(Deserialize, Debug, Clone, PartialEq)]
pub struct QueryRouter {
pub default_role: String,
pub query_parser_enabled: bool,
pub primary_reads_enabled: bool,
pub sharding_function: String,
}
impl Default for QueryRouter {
fn default() -> QueryRouter {
QueryRouter {
default_role: String::from("any"),
query_parser_enabled: false,
primary_reads_enabled: true,
sharding_function: "pg_bigint_hash".to_string(),
}
}
}
fn default_path() -> String { fn default_path() -> String {
String::from("pgcat.toml") String::from("pgcat.toml")
} }
@@ -181,9 +194,7 @@ pub struct Config {
pub path: String, pub path: String,
pub general: General, pub general: General,
pub user: User, pub pools: HashMap<String, Pool>,
pub shards: HashMap<String, Shard>,
pub query_router: QueryRouter,
} }
impl Default for Config { impl Default for Config {
@@ -191,26 +202,58 @@ impl Default for Config {
Config { Config {
path: String::from("pgcat.toml"), path: String::from("pgcat.toml"),
general: General::default(), general: General::default(),
user: User::default(), pools: HashMap::default(),
shards: HashMap::from([(String::from("1"), Shard::default())]),
query_router: QueryRouter::default(),
} }
} }
} }
impl From<&Config> for std::collections::HashMap<String, String> { impl From<&Config> for std::collections::HashMap<String, String> {
fn from(config: &Config) -> HashMap<String, String> { fn from(config: &Config) -> HashMap<String, String> {
HashMap::from([ let mut r: Vec<(String, String)> = config
.pools
.iter()
.flat_map(|(pool_name, pool)| {
[
(
format!("pools.{}.pool_mode", pool_name),
pool.pool_mode.clone(),
),
(
format!("pools.{}.primary_reads_enabled", pool_name),
pool.primary_reads_enabled.to_string(),
),
(
format!("pools.{}.query_parser_enabled", pool_name),
pool.query_parser_enabled.to_string(),
),
(
format!("pools.{}.default_role", pool_name),
pool.default_role.clone(),
),
(
format!("pools.{}.sharding_function", pool_name),
pool.sharding_function.clone(),
),
(
format!("pools.{:?}.shard_count", pool_name),
pool.shards.len().to_string(),
),
(
format!("pools.{:?}.users", pool_name),
pool.users
.iter()
.map(|(_username, user)| &user.username)
.cloned()
.collect::<Vec<String>>()
.join(", "),
),
]
})
.collect();
let mut static_settings = vec![
("host".to_string(), config.general.host.to_string()), ("host".to_string(), config.general.host.to_string()),
("port".to_string(), config.general.port.to_string()), ("port".to_string(), config.general.port.to_string()),
(
"pool_size".to_string(),
config.general.pool_size.to_string(),
),
(
"pool_mode".to_string(),
config.general.pool_mode.to_string(),
),
( (
"connect_timeout".to_string(), "connect_timeout".to_string(),
config.general.connect_timeout.to_string(), config.general.connect_timeout.to_string(),
@@ -220,42 +263,22 @@ impl From<&Config> for std::collections::HashMap<String, String> {
config.general.healthcheck_timeout.to_string(), config.general.healthcheck_timeout.to_string(),
), ),
("ban_time".to_string(), config.general.ban_time.to_string()), ("ban_time".to_string(), config.general.ban_time.to_string()),
( ];
"default_role".to_string(),
config.query_router.default_role.to_string(), r.append(&mut static_settings);
), return r.iter().cloned().collect();
(
"query_parser_enabled".to_string(),
config.query_router.query_parser_enabled.to_string(),
),
(
"primary_reads_enabled".to_string(),
config.query_router.primary_reads_enabled.to_string(),
),
(
"sharding_function".to_string(),
config.query_router.sharding_function.to_string(),
),
])
} }
} }
impl Config { impl Config {
/// Print current configuration. /// Print current configuration.
pub fn show(&self) { pub fn show(&self) {
info!("Pool size: {}", self.general.pool_size);
info!("Pool mode: {}", self.general.pool_mode);
info!("Ban time: {}s", self.general.ban_time); info!("Ban time: {}s", self.general.ban_time);
info!( info!(
"Healthcheck timeout: {}ms", "Healthcheck timeout: {}ms",
self.general.healthcheck_timeout self.general.healthcheck_timeout
); );
info!("Connection timeout: {}ms", self.general.connect_timeout); info!("Connection timeout: {}ms", self.general.connect_timeout);
info!("Sharding function: {}", self.query_router.sharding_function);
info!("Primary reads: {}", self.query_router.primary_reads_enabled);
info!("Query router: {}", self.query_router.query_parser_enabled);
info!("Number of shards: {}", self.shards.len());
match self.general.tls_certificate.clone() { match self.general.tls_certificate.clone() {
Some(tls_certificate) => { Some(tls_certificate) => {
info!("TLS certificate: {}", tls_certificate); info!("TLS certificate: {}", tls_certificate);
@@ -274,6 +297,25 @@ impl Config {
info!("TLS support is disabled"); info!("TLS support is disabled");
} }
}; };
for (pool_name, pool_config) in &self.pools {
info!("--- Settings for pool {} ---", pool_name);
info!(
"Pool size from all users: {}",
pool_config
.users
.iter()
.map(|(_, user_cfg)| user_cfg.pool_size)
.sum::<u32>()
.to_string()
);
info!("Pool mode: {}", pool_config.pool_mode);
info!("Sharding function: {}", pool_config.sharding_function);
info!("Primary reads: {}", pool_config.primary_reads_enabled);
info!("Query router: {}", pool_config.query_parser_enabled);
info!("Number of shards: {}", pool_config.shards.len());
info!("Number of users: {}", pool_config.users.len());
}
} }
} }
@@ -311,88 +353,6 @@ pub async fn parse(path: &str) -> Result<(), Error> {
} }
}; };
match config.query_router.sharding_function.as_ref() {
"pg_bigint_hash" => (),
"sha1" => (),
_ => {
error!(
"Supported sharding functions are: 'pg_bigint_hash', 'sha1', got: '{}'",
config.query_router.sharding_function
);
return Err(Error::BadConfig);
}
};
// Quick config sanity check.
for shard in &config.shards {
// We use addresses as unique identifiers,
// let's make sure they are unique in the config as well.
let mut dup_check = HashSet::new();
let mut primary_count = 0;
match shard.0.parse::<usize>() {
Ok(_) => (),
Err(_) => {
error!(
"Shard '{}' is not a valid number, shards must be numbered starting at 0",
shard.0
);
return Err(Error::BadConfig);
}
};
if shard.1.servers.len() == 0 {
error!("Shard {} has no servers configured", shard.0);
return Err(Error::BadConfig);
}
for server in &shard.1.servers {
dup_check.insert(server);
// Check that we define only zero or one primary.
match server.2.as_ref() {
"primary" => primary_count += 1,
_ => (),
};
// Check role spelling.
match server.2.as_ref() {
"primary" => (),
"replica" => (),
_ => {
error!(
"Shard {} server role must be either 'primary' or 'replica', got: '{}'",
shard.0, server.2
);
return Err(Error::BadConfig);
}
};
}
if primary_count > 1 {
error!("Shard {} has more than on primary configured", &shard.0);
return Err(Error::BadConfig);
}
if dup_check.len() != shard.1.servers.len() {
error!("Shard {} contains duplicate server configs", &shard.0);
return Err(Error::BadConfig);
}
}
match config.query_router.default_role.as_ref() {
"any" => (),
"primary" => (),
"replica" => (),
other => {
error!(
"Query router default_role must be 'primary', 'replica', or 'any', got: '{}'",
other
);
return Err(Error::BadConfig);
}
};
// Validate TLS! // Validate TLS!
match config.general.tls_certificate.clone() { match config.general.tls_certificate.clone() {
Some(tls_certificate) => { Some(tls_certificate) => {
@@ -424,6 +384,90 @@ pub async fn parse(path: &str) -> Result<(), Error> {
None => (), None => (),
}; };
for (pool_name, pool) in &config.pools {
match pool.sharding_function.as_ref() {
"pg_bigint_hash" => (),
"sha1" => (),
_ => {
error!(
"Supported sharding functions are: 'pg_bigint_hash', 'sha1', got: '{}' in pool {} settings",
pool.sharding_function,
pool_name
);
return Err(Error::BadConfig);
}
};
match pool.default_role.as_ref() {
"any" => (),
"primary" => (),
"replica" => (),
other => {
error!(
"Query router default_role must be 'primary', 'replica', or 'any', got: '{}'",
other
);
return Err(Error::BadConfig);
}
};
for shard in &pool.shards {
// We use addresses as unique identifiers,
// let's make sure they are unique in the config as well.
let mut dup_check = HashSet::new();
let mut primary_count = 0;
match shard.0.parse::<usize>() {
Ok(_) => (),
Err(_) => {
error!(
"Shard '{}' is not a valid number, shards must be numbered starting at 0",
shard.0
);
return Err(Error::BadConfig);
}
};
if shard.1.servers.len() == 0 {
error!("Shard {} has no servers configured", shard.0);
return Err(Error::BadConfig);
}
for server in &shard.1.servers {
dup_check.insert(server);
// Check that we define only zero or one primary.
match server.2.as_ref() {
"primary" => primary_count += 1,
_ => (),
};
// Check role spelling.
match server.2.as_ref() {
"primary" => (),
"replica" => (),
_ => {
error!(
"Shard {} server role must be either 'primary' or 'replica', got: '{}'",
shard.0, server.2
);
return Err(Error::BadConfig);
}
};
}
if primary_count > 1 {
error!("Shard {} has more than on primary configured", &shard.0);
return Err(Error::BadConfig);
}
if dup_check.len() != shard.1.servers.len() {
error!("Shard {} contains duplicate server configs", &shard.0);
return Err(Error::BadConfig);
}
}
}
config.path = path.to_string(); config.path = path.to_string();
// Update the configuration globally. // Update the configuration globally.
@@ -434,7 +478,6 @@ pub async fn parse(path: &str) -> Result<(), Error> {
pub async fn reload_config(client_server_map: ClientServerMap) -> Result<bool, Error> { pub async fn reload_config(client_server_map: ClientServerMap) -> Result<bool, Error> {
let old_config = get_config(); let old_config = get_config();
match parse(&old_config.path).await { match parse(&old_config.path).await {
Ok(()) => (), Ok(()) => (),
Err(err) => { Err(err) => {
@@ -442,11 +485,10 @@ pub async fn reload_config(client_server_map: ClientServerMap) -> Result<bool, E
return Err(Error::BadConfig); return Err(Error::BadConfig);
} }
}; };
let new_config = get_config(); let new_config = get_config();
if old_config.shards != new_config.shards || old_config.user != new_config.user { if old_config.pools != new_config.pools {
info!("Sharding configuration changed, re-creating server pools"); info!("Pool configuration changed, re-creating server pools");
ConnectionPool::from_config(client_server_map).await?; ConnectionPool::from_config(client_server_map).await?;
Ok(true) Ok(true)
} else if old_config != new_config { } else if old_config != new_config {
@@ -463,11 +505,58 @@ mod test {
#[tokio::test] #[tokio::test]
async fn test_config() { async fn test_config() {
parse("pgcat.toml").await.unwrap(); parse("pgcat.toml").await.unwrap();
assert_eq!(get_config().general.pool_size, 15);
assert_eq!(get_config().shards.len(), 3);
assert_eq!(get_config().shards["1"].servers[0].0, "127.0.0.1");
assert_eq!(get_config().shards["0"].servers[0].2, "primary");
assert_eq!(get_config().query_router.default_role, "any");
assert_eq!(get_config().path, "pgcat.toml".to_string()); assert_eq!(get_config().path, "pgcat.toml".to_string());
assert_eq!(get_config().general.ban_time, 60);
assert_eq!(get_config().pools.len(), 2);
assert_eq!(get_config().pools["sharded"].shards.len(), 3);
assert_eq!(get_config().pools["simple_db"].shards.len(), 1);
assert_eq!(get_config().pools["sharded"].users.len(), 2);
assert_eq!(get_config().pools["simple_db"].users.len(), 1);
assert_eq!(
get_config().pools["sharded"].shards["0"].servers[0].0,
"127.0.0.1"
);
assert_eq!(
get_config().pools["sharded"].shards["1"].servers[0].2,
"primary"
);
assert_eq!(get_config().pools["sharded"].shards["1"].database, "shard1");
assert_eq!(
get_config().pools["sharded"].users["0"].username,
"sharding_user"
);
assert_eq!(
get_config().pools["sharded"].users["1"].password,
"other_user"
);
assert_eq!(get_config().pools["sharded"].users["1"].pool_size, 21);
assert_eq!(get_config().pools["sharded"].default_role, "any");
assert_eq!(
get_config().pools["simple_db"].shards["0"].servers[0].0,
"127.0.0.1"
);
assert_eq!(
get_config().pools["simple_db"].shards["0"].servers[0].1,
5432
);
assert_eq!(
get_config().pools["simple_db"].shards["0"].database,
"some_db"
);
assert_eq!(get_config().pools["simple_db"].default_role, "primary");
assert_eq!(
get_config().pools["simple_db"].users["0"].username,
"simple_user"
);
assert_eq!(
get_config().pools["simple_db"].users["0"].password,
"simple_user"
);
assert_eq!(get_config().pools["simple_db"].users["0"].pool_size, 5);
} }
} }

View File

@@ -66,10 +66,12 @@ use config::{get_config, reload_config};
use pool::{ClientServerMap, ConnectionPool}; use pool::{ClientServerMap, ConnectionPool};
use stats::{Collector, Reporter, REPORTER}; use stats::{Collector, Reporter, REPORTER};
use crate::config::VERSION;
#[tokio::main(worker_threads = 4)] #[tokio::main(worker_threads = 4)]
async fn main() { async fn main() {
env_logger::init(); env_logger::init();
info!("Welcome to PgCat! Meow."); info!("Welcome to PgCat! Meow. (Version {})", VERSION);
if !query_router::QueryRouter::setup() { if !query_router::QueryRouter::setup() {
error!("Could not setup query router"); error!("Could not setup query router");

View File

@@ -10,19 +10,43 @@ use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use std::time::Instant; use std::time::Instant;
use crate::config::{get_config, Address, Role, User}; use crate::config::{get_config, Address, Role, Shard, User};
use crate::errors::Error; use crate::errors::Error;
use crate::server::Server; use crate::server::Server;
use crate::stats::{get_reporter, Reporter}; use crate::stats::{get_reporter, Reporter};
pub type BanList = Arc<RwLock<Vec<HashMap<Address, NaiveDateTime>>>>; pub type BanList = Arc<RwLock<Vec<HashMap<Address, NaiveDateTime>>>>;
pub type ClientServerMap = Arc<Mutex<HashMap<(i32, i32), (i32, i32, String, String)>>>; pub type ClientServerMap = Arc<Mutex<HashMap<(i32, i32), (i32, i32, String, String)>>>;
pub type PoolMap = HashMap<(String, String), ConnectionPool>;
/// The connection pool, globally available. /// The connection pool, globally available.
/// This is atomic and safe and read-optimized. /// This is atomic and safe and read-optimized.
/// The pool is recreated dynamically when the config is reloaded. /// The pool is recreated dynamically when the config is reloaded.
pub static POOL: Lazy<ArcSwap<ConnectionPool>> = pub static POOLS: Lazy<ArcSwap<PoolMap>> = Lazy::new(|| ArcSwap::from_pointee(HashMap::default()));
Lazy::new(|| ArcSwap::from_pointee(ConnectionPool::default()));
#[derive(Clone, Debug)]
pub struct PoolSettings {
pub pool_mode: String,
pub shards: HashMap<String, Shard>,
pub user: User,
pub default_role: String,
pub query_parser_enabled: bool,
pub primary_reads_enabled: bool,
pub sharding_function: String,
}
impl Default for PoolSettings {
fn default() -> PoolSettings {
PoolSettings {
pool_mode: String::from("transaction"),
shards: HashMap::from([(String::from("1"), Shard::default())]),
user: User::default(),
default_role: String::from("any"),
query_parser_enabled: false,
primary_reads_enabled: true,
sharding_function: "pg_bigint_hash".to_string(),
}
}
}
/// The globally accessible connection pool. /// The globally accessible connection pool.
#[derive(Clone, Debug, Default)] #[derive(Clone, Debug, Default)]
@@ -46,107 +70,124 @@ pub struct ConnectionPool {
/// clients on startup. We pre-connect to all shards and replicas /// clients on startup. We pre-connect to all shards and replicas
/// on pool creation and save the K messages here. /// on pool creation and save the K messages here.
server_info: BytesMut, server_info: BytesMut,
pub settings: PoolSettings,
} }
impl ConnectionPool { impl ConnectionPool {
/// Construct the connection pool from the configuration. /// Construct the connection pool from the configuration.
pub async fn from_config(client_server_map: ClientServerMap) -> Result<(), Error> { pub async fn from_config(client_server_map: ClientServerMap) -> Result<(), Error> {
let reporter = get_reporter();
let config = get_config(); let config = get_config();
let mut new_pools = PoolMap::default();
let mut shards = Vec::new();
let mut addresses = Vec::new();
let mut banlist = Vec::new();
let mut address_id = 0; let mut address_id = 0;
let mut shard_ids = config for (pool_name, pool_config) in &config.pools {
.shards for (_user_index, user_info) in &pool_config.users {
.clone() let mut shards = Vec::new();
.into_keys() let mut addresses = Vec::new();
.map(|x| x.to_string()) let mut banlist = Vec::new();
.collect::<Vec<String>>(); let mut shard_ids = pool_config
.shards
.clone()
.into_keys()
.map(|x| x.to_string())
.collect::<Vec<String>>();
// Sort by shard number to ensure consistency. // Sort by shard number to ensure consistency.
shard_ids.sort_by_key(|k| k.parse::<i64>().unwrap()); shard_ids.sort_by_key(|k| k.parse::<i64>().unwrap());
for shard_idx in shard_ids { for shard_idx in shard_ids {
let shard = &config.shards[&shard_idx]; let shard = &pool_config.shards[&shard_idx];
let mut pools = Vec::new(); let mut pools = Vec::new();
let mut servers = Vec::new(); let mut servers = Vec::new();
let mut replica_number = 0; let mut replica_number = 0;
for server in shard.servers.iter() { for server in shard.servers.iter() {
let role = match server.2.as_ref() { let role = match server.2.as_ref() {
"primary" => Role::Primary, "primary" => Role::Primary,
"replica" => Role::Replica, "replica" => Role::Replica,
_ => { _ => {
error!("Config error: server role can be 'primary' or 'replica', have: '{}'. Defaulting to 'replica'.", server.2); error!("Config error: server role can be 'primary' or 'replica', have: '{}'. Defaulting to 'replica'.", server.2);
Role::Replica Role::Replica
}
};
let address = Address {
id: address_id,
database: pool_name.clone(),
host: server.0.clone(),
port: server.1.to_string(),
role: role,
replica_number,
shard: shard_idx.parse::<usize>().unwrap(),
};
address_id += 1;
if role == Role::Replica {
replica_number += 1;
}
let manager = ServerPool::new(
address.clone(),
user_info.clone(),
&shard.database,
client_server_map.clone(),
get_reporter(),
);
let pool = Pool::builder()
.max_size(user_info.pool_size)
.connection_timeout(std::time::Duration::from_millis(
config.general.connect_timeout,
))
.test_on_check_out(false)
.build(manager)
.await
.unwrap();
pools.push(pool);
servers.push(address);
} }
};
let address = Address { shards.push(pools);
id: address_id, addresses.push(servers);
host: server.0.clone(), banlist.push(HashMap::new());
port: server.1.to_string(),
role: role,
replica_number,
shard: shard_idx.parse::<usize>().unwrap(),
};
address_id += 1;
if role == Role::Replica {
replica_number += 1;
} }
let manager = ServerPool::new( assert_eq!(shards.len(), addresses.len());
address.clone(),
config.user.clone(),
&shard.database,
client_server_map.clone(),
reporter.clone(),
);
let pool = Pool::builder() let mut pool = ConnectionPool {
.max_size(config.general.pool_size) databases: shards,
.connection_timeout(std::time::Duration::from_millis( addresses: addresses,
config.general.connect_timeout, banlist: Arc::new(RwLock::new(banlist)),
)) stats: get_reporter(),
.test_on_check_out(false) server_info: BytesMut::new(),
.build(manager) settings: PoolSettings {
.await pool_mode: pool_config.pool_mode.clone(),
.unwrap(); shards: pool_config.shards.clone(),
user: user_info.clone(),
default_role: pool_config.default_role.clone(),
query_parser_enabled: pool_config.query_parser_enabled.clone(),
primary_reads_enabled: pool_config.primary_reads_enabled,
sharding_function: pool_config.sharding_function.clone(),
},
};
pools.push(pool); // Connect to the servers to make sure pool configuration is valid
servers.push(address); // before setting it globally.
match pool.validate().await {
Ok(_) => (),
Err(err) => {
error!("Could not validate connection pool: {:?}", err);
return Err(err);
}
};
new_pools.insert((pool_name.clone(), user_info.username.clone()), pool);
} }
shards.push(pools);
addresses.push(servers);
banlist.push(HashMap::new());
} }
assert_eq!(shards.len(), addresses.len()); POOLS.store(Arc::new(new_pools.clone()));
let mut pool = ConnectionPool {
databases: shards,
addresses: addresses,
banlist: Arc::new(RwLock::new(banlist)),
stats: reporter,
server_info: BytesMut::new(),
};
// Connect to the servers to make sure pool configuration is valid
// before setting it globally.
match pool.validate().await {
Ok(_) => (),
Err(err) => {
error!("Could not validate connection pool: {:?}", err);
return Err(err);
}
};
POOL.store(Arc::new(pool.clone()));
Ok(()) Ok(())
} }
@@ -474,7 +515,7 @@ impl ManageConnection for ServerPool {
info!( info!(
"Creating a new connection to {:?} using user {:?}", "Creating a new connection to {:?} using user {:?}",
self.address.name(), self.address.name(),
self.user.name self.user.username
); );
// Put a temporary process_id into the stats // Put a temporary process_id into the stats
@@ -517,6 +558,20 @@ impl ManageConnection for ServerPool {
} }
/// Get the connection pool /// Get the connection pool
pub fn get_pool() -> ConnectionPool { pub fn get_pool(db: String, user: String) -> Option<ConnectionPool> {
(*(*POOL.load())).clone() match get_all_pools().get(&(db, user)) {
Some(pool) => Some(pool.clone()),
None => None,
}
}
pub fn get_number_of_addresses() -> usize {
get_all_pools()
.iter()
.map(|(_, pool)| pool.databases())
.sum()
}
pub fn get_all_pools() -> HashMap<(String, String), ConnectionPool> {
return (*(*POOLS.load())).clone();
} }

View File

@@ -8,7 +8,8 @@ use sqlparser::ast::Statement::{Query, StartTransaction};
use sqlparser::dialect::PostgreSqlDialect; use sqlparser::dialect::PostgreSqlDialect;
use sqlparser::parser::Parser; use sqlparser::parser::Parser;
use crate::config::{get_config, Role}; use crate::config::Role;
use crate::pool::{ConnectionPool, PoolSettings};
use crate::sharding::{Sharder, ShardingFunction}; use crate::sharding::{Sharder, ShardingFunction};
/// Regexes used to parse custom commands. /// Regexes used to parse custom commands.
@@ -53,6 +54,8 @@ pub struct QueryRouter {
/// Include the primary into the replica pool for reads. /// Include the primary into the replica pool for reads.
primary_reads_enabled: bool, primary_reads_enabled: bool,
pool_settings: PoolSettings,
} }
impl QueryRouter { impl QueryRouter {
@@ -88,14 +91,13 @@ impl QueryRouter {
} }
/// Create a new instance of the query router. Each client gets its own. /// Create a new instance of the query router. Each client gets its own.
pub fn new() -> QueryRouter { pub fn new(target_pool: ConnectionPool) -> QueryRouter {
let config = get_config();
QueryRouter { QueryRouter {
active_shard: None, active_shard: None,
active_role: None, active_role: None,
query_parser_enabled: config.query_router.query_parser_enabled, query_parser_enabled: target_pool.settings.query_parser_enabled,
primary_reads_enabled: config.query_router.primary_reads_enabled, primary_reads_enabled: target_pool.settings.primary_reads_enabled,
pool_settings: target_pool.settings,
} }
} }
@@ -130,15 +132,13 @@ impl QueryRouter {
return None; return None;
} }
let config = get_config(); let sharding_function = match self.pool_settings.sharding_function.as_ref() {
let sharding_function = match config.query_router.sharding_function.as_ref() {
"pg_bigint_hash" => ShardingFunction::PgBigintHash, "pg_bigint_hash" => ShardingFunction::PgBigintHash,
"sha1" => ShardingFunction::Sha1, "sha1" => ShardingFunction::Sha1,
_ => unreachable!(), _ => unreachable!(),
}; };
let default_server_role = match config.query_router.default_role.as_ref() { let default_server_role = match self.pool_settings.default_role.as_ref() {
"any" => None, "any" => None,
"primary" => Some(Role::Primary), "primary" => Some(Role::Primary),
"replica" => Some(Role::Replica), "replica" => Some(Role::Replica),
@@ -196,7 +196,7 @@ impl QueryRouter {
match command { match command {
Command::SetShardingKey => { Command::SetShardingKey => {
let sharder = Sharder::new(config.shards.len(), sharding_function); let sharder = Sharder::new(self.pool_settings.shards.len(), sharding_function);
let shard = sharder.shard(value.parse::<i64>().unwrap()); let shard = sharder.shard(value.parse::<i64>().unwrap());
self.active_shard = Some(shard); self.active_shard = Some(shard);
value = shard.to_string(); value = shard.to_string();
@@ -204,7 +204,7 @@ impl QueryRouter {
Command::SetShard => { Command::SetShard => {
self.active_shard = match value.to_ascii_uppercase().as_ref() { self.active_shard = match value.to_ascii_uppercase().as_ref() {
"ANY" => Some(rand::random::<usize>() % config.shards.len()), "ANY" => Some(rand::random::<usize>() % self.pool_settings.shards.len()),
_ => Some(value.parse::<usize>().unwrap()), _ => Some(value.parse::<usize>().unwrap()),
}; };
} }
@@ -233,7 +233,7 @@ impl QueryRouter {
"default" => { "default" => {
self.active_role = default_server_role; self.active_role = default_server_role;
self.query_parser_enabled = config.query_router.query_parser_enabled; self.query_parser_enabled = self.query_parser_enabled;
self.active_role self.active_role
} }
@@ -250,7 +250,7 @@ impl QueryRouter {
self.primary_reads_enabled = false; self.primary_reads_enabled = false;
} else if value == "default" { } else if value == "default" {
debug!("Setting primary reads to default"); debug!("Setting primary reads to default");
self.primary_reads_enabled = config.query_router.primary_reads_enabled; self.primary_reads_enabled = self.pool_settings.primary_reads_enabled;
} }
} }
@@ -370,7 +370,7 @@ mod test {
#[test] #[test]
fn test_defaults() { fn test_defaults() {
QueryRouter::setup(); QueryRouter::setup();
let qr = QueryRouter::new(); let qr = QueryRouter::new(ConnectionPool::default());
assert_eq!(qr.role(), None); assert_eq!(qr.role(), None);
} }
@@ -378,7 +378,7 @@ mod test {
#[test] #[test]
fn test_infer_role_replica() { fn test_infer_role_replica() {
QueryRouter::setup(); QueryRouter::setup();
let mut qr = QueryRouter::new(); let mut qr = QueryRouter::new(ConnectionPool::default());
assert!(qr.try_execute_command(simple_query("SET SERVER ROLE TO 'auto'")) != None); assert!(qr.try_execute_command(simple_query("SET SERVER ROLE TO 'auto'")) != None);
assert_eq!(qr.query_parser_enabled(), true); assert_eq!(qr.query_parser_enabled(), true);
@@ -402,7 +402,7 @@ mod test {
#[test] #[test]
fn test_infer_role_primary() { fn test_infer_role_primary() {
QueryRouter::setup(); QueryRouter::setup();
let mut qr = QueryRouter::new(); let mut qr = QueryRouter::new(ConnectionPool::default());
let queries = vec![ let queries = vec![
simple_query("UPDATE items SET name = 'pumpkin' WHERE id = 5"), simple_query("UPDATE items SET name = 'pumpkin' WHERE id = 5"),
@@ -421,7 +421,7 @@ mod test {
#[test] #[test]
fn test_infer_role_primary_reads_enabled() { fn test_infer_role_primary_reads_enabled() {
QueryRouter::setup(); QueryRouter::setup();
let mut qr = QueryRouter::new(); let mut qr = QueryRouter::new(ConnectionPool::default());
let query = simple_query("SELECT * FROM items WHERE id = 5"); let query = simple_query("SELECT * FROM items WHERE id = 5");
assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO on")) != None); assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO on")) != None);
@@ -432,7 +432,7 @@ mod test {
#[test] #[test]
fn test_infer_role_parse_prepared() { fn test_infer_role_parse_prepared() {
QueryRouter::setup(); QueryRouter::setup();
let mut qr = QueryRouter::new(); let mut qr = QueryRouter::new(ConnectionPool::default());
qr.try_execute_command(simple_query("SET SERVER ROLE TO 'auto'")); qr.try_execute_command(simple_query("SET SERVER ROLE TO 'auto'"));
assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None); assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None);
@@ -523,15 +523,15 @@ mod test {
#[test] #[test]
fn test_try_execute_command() { fn test_try_execute_command() {
QueryRouter::setup(); QueryRouter::setup();
let mut qr = QueryRouter::new(); let mut qr = QueryRouter::new(ConnectionPool::default());
// SetShardingKey // SetShardingKey
let query = simple_query("SET SHARDING KEY TO 13"); let query = simple_query("SET SHARDING KEY TO 13");
assert_eq!( assert_eq!(
qr.try_execute_command(query), qr.try_execute_command(query),
Some((Command::SetShardingKey, String::from("1"))) Some((Command::SetShardingKey, String::from("0")))
); );
assert_eq!(qr.shard(), 1); assert_eq!(qr.shard(), 0);
// SetShard // SetShard
let query = simple_query("SET SHARD TO '1'"); let query = simple_query("SET SHARD TO '1'");
@@ -600,7 +600,7 @@ mod test {
#[test] #[test]
fn test_enable_query_parser() { fn test_enable_query_parser() {
QueryRouter::setup(); QueryRouter::setup();
let mut qr = QueryRouter::new(); let mut qr = QueryRouter::new(ConnectionPool::default());
let query = simple_query("SET SERVER ROLE TO 'auto'"); let query = simple_query("SET SERVER ROLE TO 'auto'");
assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None); assert!(qr.try_execute_command(simple_query("SET PRIMARY READS TO off")) != None);

View File

@@ -82,7 +82,7 @@ impl Server {
trace!("Sending StartupMessage"); trace!("Sending StartupMessage");
// StartupMessage // StartupMessage
startup(&mut stream, &user.name, database).await?; startup(&mut stream, &user.username, database).await?;
let mut server_info = BytesMut::new(); let mut server_info = BytesMut::new();
let mut process_id: i32 = 0; let mut process_id: i32 = 0;
@@ -127,7 +127,7 @@ impl Server {
Err(_) => return Err(Error::SocketError), Err(_) => return Err(Error::SocketError),
}; };
md5_password(&mut stream, &user.name, &user.password, &salt[..]) md5_password(&mut stream, &user.username, &user.password, &salt[..])
.await?; .await?;
} }

View File

@@ -6,7 +6,7 @@ use parking_lot::Mutex;
use std::collections::HashMap; use std::collections::HashMap;
use tokio::sync::mpsc::{channel, Receiver, Sender}; use tokio::sync::mpsc::{channel, Receiver, Sender};
use crate::pool::get_pool; use crate::pool::get_number_of_addresses;
pub static REPORTER: Lazy<ArcSwap<Reporter>> = pub static REPORTER: Lazy<ArcSwap<Reporter>> =
Lazy::new(|| ArcSwap::from_pointee(Reporter::default())); Lazy::new(|| ArcSwap::from_pointee(Reporter::default()));
@@ -331,8 +331,8 @@ impl Collector {
tokio::time::interval(tokio::time::Duration::from_millis(STAT_PERIOD / 15)); tokio::time::interval(tokio::time::Duration::from_millis(STAT_PERIOD / 15));
loop { loop {
interval.tick().await; interval.tick().await;
let addresses = get_pool().databases(); let address_count = get_number_of_addresses();
for address_id in 0..addresses { for address_id in 0..address_count {
let _ = tx.try_send(Event { let _ = tx.try_send(Event {
name: EventName::UpdateStats, name: EventName::UpdateStats,
value: 0, value: 0,
@@ -349,8 +349,8 @@ impl Collector {
tokio::time::interval(tokio::time::Duration::from_millis(STAT_PERIOD)); tokio::time::interval(tokio::time::Duration::from_millis(STAT_PERIOD));
loop { loop {
interval.tick().await; interval.tick().await;
let addresses = get_pool().databases(); let address_count = get_number_of_addresses();
for address_id in 0..addresses { for address_id in 0..address_count {
let _ = tx.try_send(Event { let _ = tx.try_send(Event {
name: EventName::UpdateAverages, name: EventName::UpdateAverages,
value: 0, value: 0,

View File

@@ -15,7 +15,7 @@ ActiveRecord::Base.establish_connection(
port: 6432, port: 6432,
username: 'sharding_user', username: 'sharding_user',
password: 'sharding_user', password: 'sharding_user',
database: 'rails_dev', database: 'sharded_db',
application_name: 'testing_pgcat', application_name: 'testing_pgcat',
prepared_statements: false, # Transaction mode prepared_statements: false, # Transaction mode
advisory_locks: false # Same advisory_locks: false # Same
@@ -117,7 +117,7 @@ end
# Test evil clients # Test evil clients
def poorly_behaved_client def poorly_behaved_client
conn = PG::connect("postgres://sharding_user:sharding_user@127.0.0.1:6432/rails_dev?application_name=testing_pgcat") conn = PG::connect("postgres://sharding_user:sharding_user@127.0.0.1:6432/sharded_db?application_name=testing_pgcat")
conn.async_exec 'BEGIN' conn.async_exec 'BEGIN'
conn.async_exec 'SELECT 1' conn.async_exec 'SELECT 1'