Add support for multi-database / multi-user pools (#96)

* Add support for multi-database / multi-user pools

* Nothing

* cargo fmt

* CI

* remove test users

* rename pool

* Update tests to use admin user/pass

* more fixes

* Revert bad change

* Use PGDATABASE env var

* send server info in case of admin
This commit is contained in:
Mostafa Abdelraouf
2022-07-27 21:47:55 -05:00
committed by GitHub
parent c5be5565a5
commit 2ae4b438e3
14 changed files with 700 additions and 503 deletions

View File

@@ -10,19 +10,43 @@ use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use crate::config::{get_config, Address, Role, User};
use crate::config::{get_config, Address, Role, Shard, User};
use crate::errors::Error;
use crate::server::Server;
use crate::stats::{get_reporter, Reporter};
pub type BanList = Arc<RwLock<Vec<HashMap<Address, NaiveDateTime>>>>;
pub type ClientServerMap = Arc<Mutex<HashMap<(i32, i32), (i32, i32, String, String)>>>;
pub type PoolMap = HashMap<(String, String), ConnectionPool>;
/// The connection pool, globally available.
/// This is atomic and safe and read-optimized.
/// The pool is recreated dynamically when the config is reloaded.
pub static POOL: Lazy<ArcSwap<ConnectionPool>> =
Lazy::new(|| ArcSwap::from_pointee(ConnectionPool::default()));
pub static POOLS: Lazy<ArcSwap<PoolMap>> = Lazy::new(|| ArcSwap::from_pointee(HashMap::default()));
#[derive(Clone, Debug)]
pub struct PoolSettings {
pub pool_mode: String,
pub shards: HashMap<String, Shard>,
pub user: User,
pub default_role: String,
pub query_parser_enabled: bool,
pub primary_reads_enabled: bool,
pub sharding_function: String,
}
impl Default for PoolSettings {
fn default() -> PoolSettings {
PoolSettings {
pool_mode: String::from("transaction"),
shards: HashMap::from([(String::from("1"), Shard::default())]),
user: User::default(),
default_role: String::from("any"),
query_parser_enabled: false,
primary_reads_enabled: true,
sharding_function: "pg_bigint_hash".to_string(),
}
}
}
/// The globally accessible connection pool.
#[derive(Clone, Debug, Default)]
@@ -46,107 +70,124 @@ pub struct ConnectionPool {
/// clients on startup. We pre-connect to all shards and replicas
/// on pool creation and save the K messages here.
server_info: BytesMut,
pub settings: PoolSettings,
}
impl ConnectionPool {
/// Construct the connection pool from the configuration.
pub async fn from_config(client_server_map: ClientServerMap) -> Result<(), Error> {
let reporter = get_reporter();
let config = get_config();
let mut new_pools = PoolMap::default();
let mut shards = Vec::new();
let mut addresses = Vec::new();
let mut banlist = Vec::new();
let mut address_id = 0;
let mut shard_ids = config
.shards
.clone()
.into_keys()
.map(|x| x.to_string())
.collect::<Vec<String>>();
for (pool_name, pool_config) in &config.pools {
for (_user_index, user_info) in &pool_config.users {
let mut shards = Vec::new();
let mut addresses = Vec::new();
let mut banlist = Vec::new();
let mut shard_ids = pool_config
.shards
.clone()
.into_keys()
.map(|x| x.to_string())
.collect::<Vec<String>>();
// Sort by shard number to ensure consistency.
shard_ids.sort_by_key(|k| k.parse::<i64>().unwrap());
// Sort by shard number to ensure consistency.
shard_ids.sort_by_key(|k| k.parse::<i64>().unwrap());
for shard_idx in shard_ids {
let shard = &config.shards[&shard_idx];
let mut pools = Vec::new();
let mut servers = Vec::new();
let mut replica_number = 0;
for shard_idx in shard_ids {
let shard = &pool_config.shards[&shard_idx];
let mut pools = Vec::new();
let mut servers = Vec::new();
let mut replica_number = 0;
for server in shard.servers.iter() {
let role = match server.2.as_ref() {
"primary" => Role::Primary,
"replica" => Role::Replica,
_ => {
error!("Config error: server role can be 'primary' or 'replica', have: '{}'. Defaulting to 'replica'.", server.2);
Role::Replica
for server in shard.servers.iter() {
let role = match server.2.as_ref() {
"primary" => Role::Primary,
"replica" => Role::Replica,
_ => {
error!("Config error: server role can be 'primary' or 'replica', have: '{}'. Defaulting to 'replica'.", server.2);
Role::Replica
}
};
let address = Address {
id: address_id,
database: pool_name.clone(),
host: server.0.clone(),
port: server.1.to_string(),
role: role,
replica_number,
shard: shard_idx.parse::<usize>().unwrap(),
};
address_id += 1;
if role == Role::Replica {
replica_number += 1;
}
let manager = ServerPool::new(
address.clone(),
user_info.clone(),
&shard.database,
client_server_map.clone(),
get_reporter(),
);
let pool = Pool::builder()
.max_size(user_info.pool_size)
.connection_timeout(std::time::Duration::from_millis(
config.general.connect_timeout,
))
.test_on_check_out(false)
.build(manager)
.await
.unwrap();
pools.push(pool);
servers.push(address);
}
};
let address = Address {
id: address_id,
host: server.0.clone(),
port: server.1.to_string(),
role: role,
replica_number,
shard: shard_idx.parse::<usize>().unwrap(),
};
address_id += 1;
if role == Role::Replica {
replica_number += 1;
shards.push(pools);
addresses.push(servers);
banlist.push(HashMap::new());
}
let manager = ServerPool::new(
address.clone(),
config.user.clone(),
&shard.database,
client_server_map.clone(),
reporter.clone(),
);
assert_eq!(shards.len(), addresses.len());
let pool = Pool::builder()
.max_size(config.general.pool_size)
.connection_timeout(std::time::Duration::from_millis(
config.general.connect_timeout,
))
.test_on_check_out(false)
.build(manager)
.await
.unwrap();
let mut pool = ConnectionPool {
databases: shards,
addresses: addresses,
banlist: Arc::new(RwLock::new(banlist)),
stats: get_reporter(),
server_info: BytesMut::new(),
settings: PoolSettings {
pool_mode: pool_config.pool_mode.clone(),
shards: pool_config.shards.clone(),
user: user_info.clone(),
default_role: pool_config.default_role.clone(),
query_parser_enabled: pool_config.query_parser_enabled.clone(),
primary_reads_enabled: pool_config.primary_reads_enabled,
sharding_function: pool_config.sharding_function.clone(),
},
};
pools.push(pool);
servers.push(address);
// Connect to the servers to make sure pool configuration is valid
// before setting it globally.
match pool.validate().await {
Ok(_) => (),
Err(err) => {
error!("Could not validate connection pool: {:?}", err);
return Err(err);
}
};
new_pools.insert((pool_name.clone(), user_info.username.clone()), pool);
}
shards.push(pools);
addresses.push(servers);
banlist.push(HashMap::new());
}
assert_eq!(shards.len(), addresses.len());
let mut pool = ConnectionPool {
databases: shards,
addresses: addresses,
banlist: Arc::new(RwLock::new(banlist)),
stats: reporter,
server_info: BytesMut::new(),
};
// Connect to the servers to make sure pool configuration is valid
// before setting it globally.
match pool.validate().await {
Ok(_) => (),
Err(err) => {
error!("Could not validate connection pool: {:?}", err);
return Err(err);
}
};
POOL.store(Arc::new(pool.clone()));
POOLS.store(Arc::new(new_pools.clone()));
Ok(())
}
@@ -474,7 +515,7 @@ impl ManageConnection for ServerPool {
info!(
"Creating a new connection to {:?} using user {:?}",
self.address.name(),
self.user.name
self.user.username
);
// Put a temporary process_id into the stats
@@ -517,6 +558,20 @@ impl ManageConnection for ServerPool {
}
/// Get the connection pool
pub fn get_pool() -> ConnectionPool {
(*(*POOL.load())).clone()
pub fn get_pool(db: String, user: String) -> Option<ConnectionPool> {
match get_all_pools().get(&(db, user)) {
Some(pool) => Some(pool.clone()),
None => None,
}
}
pub fn get_number_of_addresses() -> usize {
get_all_pools()
.iter()
.map(|(_, pool)| pool.databases())
.sum()
}
pub fn get_all_pools() -> HashMap<(String, String), ConnectionPool> {
return (*(*POOLS.load())).clone();
}