mirror of
https://github.com/postgresml/pgcat.git
synced 2026-03-24 17:56:29 +00:00
Add dns_cache for server addresses as in pgbouncer (#249)
* Add dns_cache so server addresses are cached and invalidated when DNS changes. Adds a module to deal with dns_cache feature. Its main struct is CachedResolver, which is a simple thread-safe hostname <-> IPs cache with the ability to refresh resolutions every `dns_max_ttl` seconds. This way, a client can check whether its ip address has changed. * Allow reloading dns cached * Add documentation for dns_cached
This commit is contained in:
410
src/dns_cache.rs
Normal file
410
src/dns_cache.rs
Normal file
@@ -0,0 +1,410 @@
|
||||
use crate::config::get_config;
|
||||
use crate::errors::Error;
|
||||
use arc_swap::ArcSwap;
|
||||
use log::{debug, error, info, warn};
|
||||
use once_cell::sync::Lazy;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::io;
|
||||
use std::net::IpAddr;
|
||||
use std::sync::Arc;
|
||||
use std::sync::RwLock;
|
||||
use tokio::time::{sleep, Duration};
|
||||
use trust_dns_resolver::error::{ResolveError, ResolveResult};
|
||||
use trust_dns_resolver::lookup_ip::LookupIp;
|
||||
use trust_dns_resolver::TokioAsyncResolver;
|
||||
|
||||
/// Cached Resolver Globally available
|
||||
pub static CACHED_RESOLVER: Lazy<ArcSwap<CachedResolver>> =
|
||||
Lazy::new(|| ArcSwap::from_pointee(CachedResolver::default()));
|
||||
|
||||
/// Set of IP addresses a hostname resolved to.
///
/// IP addresses are kept as a set so that two resolutions can be compared
/// without caring about the order in which the addresses were returned.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct AddrSet {
    set: HashSet<IpAddr>,
}

impl AddrSet {
    /// Creates an empty `AddrSet`.
    fn new() -> AddrSet {
        AddrSet::default()
    }
}
|
||||
|
||||
impl From<LookupIp> for AddrSet {
|
||||
fn from(lookup_ip: LookupIp) -> Self {
|
||||
let mut addr_set = AddrSet::new();
|
||||
for address in lookup_ip.iter() {
|
||||
addr_set.set.insert(address);
|
||||
}
|
||||
addr_set
|
||||
}
|
||||
}
|
||||
|
||||
///
|
||||
/// A CachedResolver is a DNS resolution cache mechanism with customizable expiration time.
|
||||
///
|
||||
/// The system works as follows:
|
||||
///
|
||||
/// When a host is to be resolved, if we have not resolved it before, a new resolution is
|
||||
/// executed and stored in the internal cache. Concurrently, every `dns_max_ttl` time, the
|
||||
/// cache is refreshed.
|
||||
///
|
||||
/// # Example:
|
||||
///
|
||||
/// ```
|
||||
/// use pgcat::dns_cache::{CachedResolverConfig, CachedResolver};
|
||||
///
|
||||
/// # tokio_test::block_on(async {
|
||||
/// let config = CachedResolverConfig::default();
|
||||
/// let resolver = CachedResolver::new(config, None).await.unwrap();
|
||||
/// let addrset = resolver.lookup_ip("www.example.com.").await.unwrap();
|
||||
/// # })
|
||||
/// ```
|
||||
///
|
||||
/// // Now the ip resolution is stored in local cache and subsequent
|
||||
/// // calls will be returned from cache. Also, the cache is refreshed
|
||||
/// // and updated every 10 seconds.
|
||||
///
|
||||
/// // You can now check if an 'old' lookup differs from what it's currently
|
||||
/// // store in cache by using `has_changed`.
|
||||
/// resolver.has_changed("www.example.com.", addrset)
|
||||
#[derive(Default)]
|
||||
pub struct CachedResolver {
|
||||
// The configuration of the cached_resolver.
|
||||
config: CachedResolverConfig,
|
||||
|
||||
// This is the hash that contains the hash.
|
||||
data: Option<RwLock<HashMap<String, AddrSet>>>,
|
||||
|
||||
// The resolver to be used for DNS queries.
|
||||
resolver: Option<TokioAsyncResolver>,
|
||||
|
||||
// The RefreshLoop
|
||||
refresh_loop: RwLock<Option<tokio::task::JoinHandle<()>>>,
|
||||
}
|
||||
|
||||
///
/// Configuration
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct CachedResolverConfig {
    /// Amount of time in seconds that a resolved dns address is considered stale.
    dns_max_ttl: u64,

    /// Enabled or disabled? (this is so we can reload config)
    enabled: bool,
}

impl CachedResolverConfig {
    /// Builds a configuration from the refresh period (in seconds) and the
    /// enabled flag.
    fn new(dns_max_ttl: u64, enabled: bool) -> Self {
        Self {
            dns_max_ttl,
            enabled,
        }
    }
}
|
||||
|
||||
impl From<crate::config::Config> for CachedResolverConfig {
    /// Picks the DNS cache settings out of the `general` section of the
    /// application configuration.
    fn from(config: crate::config::Config) -> Self {
        let general = &config.general;
        Self::new(general.dns_max_ttl, general.dns_cache_enabled)
    }
}
|
||||
|
||||
impl CachedResolver {
|
||||
///
|
||||
/// Returns a new Arc<CachedResolver> based on passed configuration.
|
||||
/// It also starts the loop that will refresh cache entries.
|
||||
///
|
||||
/// # Arguments:
|
||||
///
|
||||
/// * `config` - The `CachedResolverConfig` to be used to create the resolver.
|
||||
///
|
||||
/// # Example:
|
||||
///
|
||||
/// ```
|
||||
/// use pgcat::dns_cache::{CachedResolverConfig, CachedResolver};
|
||||
///
|
||||
/// # tokio_test::block_on(async {
|
||||
/// let config = CachedResolverConfig::default();
|
||||
/// let resolver = CachedResolver::new(config, None).await.unwrap();
|
||||
/// # })
|
||||
/// ```
|
||||
///
|
||||
pub async fn new(
|
||||
config: CachedResolverConfig,
|
||||
data: Option<HashMap<String, AddrSet>>,
|
||||
) -> Result<Arc<Self>, io::Error> {
|
||||
// Construct a new Resolver with default configuration options
|
||||
let resolver = Some(TokioAsyncResolver::tokio_from_system_conf()?);
|
||||
|
||||
let data = if let Some(hash) = data {
|
||||
Some(RwLock::new(hash))
|
||||
} else {
|
||||
Some(RwLock::new(HashMap::new()))
|
||||
};
|
||||
|
||||
let instance = Arc::new(Self {
|
||||
config,
|
||||
resolver,
|
||||
data,
|
||||
refresh_loop: RwLock::new(None),
|
||||
});
|
||||
|
||||
if instance.enabled() {
|
||||
info!("Scheduling DNS refresh loop");
|
||||
let refresh_loop = tokio::task::spawn({
|
||||
let instance = instance.clone();
|
||||
async move {
|
||||
instance.refresh_dns_entries_loop().await;
|
||||
}
|
||||
});
|
||||
*(instance.refresh_loop.write().unwrap()) = Some(refresh_loop);
|
||||
}
|
||||
|
||||
Ok(instance)
|
||||
}
|
||||
|
||||
pub fn enabled(&self) -> bool {
|
||||
self.config.enabled
|
||||
}
|
||||
|
||||
// Schedules the refresher
|
||||
async fn refresh_dns_entries_loop(&self) {
|
||||
let resolver = TokioAsyncResolver::tokio_from_system_conf().unwrap();
|
||||
let interval = Duration::from_secs(self.config.dns_max_ttl);
|
||||
loop {
|
||||
debug!("Begin refreshing cached DNS addresses.");
|
||||
// To minimize the time we hold the lock, we first create
|
||||
// an array with keys.
|
||||
let mut hostnames: Vec<String> = Vec::new();
|
||||
{
|
||||
if let Some(ref data) = self.data {
|
||||
for hostname in data.read().unwrap().keys() {
|
||||
hostnames.push(hostname.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for hostname in hostnames.iter() {
|
||||
let addrset = self
|
||||
.fetch_from_cache(hostname.as_str())
|
||||
.expect("Could not obtain expected address from cache, this should not happen");
|
||||
|
||||
match resolver.lookup_ip(hostname).await {
|
||||
Ok(lookup_ip) => {
|
||||
let new_addrset = AddrSet::from(lookup_ip);
|
||||
debug!(
|
||||
"Obtained address for host ({}) -> ({:?})",
|
||||
hostname, new_addrset
|
||||
);
|
||||
|
||||
if addrset != new_addrset {
|
||||
debug!(
|
||||
"Addr changed from {:?} to {:?} updating cache.",
|
||||
addrset, new_addrset
|
||||
);
|
||||
self.store_in_cache(hostname, new_addrset);
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
error!(
|
||||
"There was an error trying to resolv {}: ({}).",
|
||||
hostname, err
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
debug!("Finished refreshing cached DNS addresses.");
|
||||
sleep(interval).await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a `AddrSet` given the specified hostname.
|
||||
///
|
||||
/// This method first tries to fetch the value from the cache, if it misses
|
||||
/// then it is resolved and stored in the cache. TTL from records is ignored.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `host` - A string slice referencing the hostname to be resolved.
|
||||
///
|
||||
/// # Example:
|
||||
///
|
||||
/// ```
|
||||
/// use pgcat::dns_cache::{CachedResolverConfig, CachedResolver};
|
||||
///
|
||||
/// # tokio_test::block_on(async {
|
||||
/// let config = CachedResolverConfig::default();
|
||||
/// let resolver = CachedResolver::new(config, None).await.unwrap();
|
||||
/// let response = resolver.lookup_ip("www.google.com.");
|
||||
/// # })
|
||||
/// ```
|
||||
///
|
||||
pub async fn lookup_ip(&self, host: &str) -> ResolveResult<AddrSet> {
|
||||
debug!("Lookup up {} in cache", host);
|
||||
match self.fetch_from_cache(host) {
|
||||
Some(addr_set) => {
|
||||
debug!("Cache hit!");
|
||||
Ok(addr_set)
|
||||
}
|
||||
None => {
|
||||
debug!("Not found, executing a dns query!");
|
||||
if let Some(ref resolver) = self.resolver {
|
||||
let addr_set = AddrSet::from(resolver.lookup_ip(host).await?);
|
||||
debug!("Obtained: {:?}", addr_set);
|
||||
self.store_in_cache(host, addr_set.clone());
|
||||
Ok(addr_set)
|
||||
} else {
|
||||
Err(ResolveError::from("No resolver available"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Returns true if the stored host resolution differs from the AddrSet passed.
|
||||
pub fn has_changed(&self, host: &str, addr_set: &AddrSet) -> bool {
|
||||
if let Some(fetched_addr_set) = self.fetch_from_cache(host) {
|
||||
return fetched_addr_set != *addr_set;
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
// Fetches an AddrSet from the inner cache adquiring the read lock.
|
||||
fn fetch_from_cache(&self, key: &str) -> Option<AddrSet> {
|
||||
if let Some(ref hash) = self.data {
|
||||
if let Some(addr_set) = hash.read().unwrap().get(key) {
|
||||
return Some(addr_set.clone());
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
// Sets up the global CACHED_RESOLVER static variable so we can globally use DNS
|
||||
// cache.
|
||||
pub async fn from_config() -> Result<(), Error> {
|
||||
let cached_resolver = CACHED_RESOLVER.load();
|
||||
let desired_config = CachedResolverConfig::from(get_config());
|
||||
|
||||
if cached_resolver.config != desired_config {
|
||||
if let Some(ref refresh_loop) = *(cached_resolver.refresh_loop.write().unwrap()) {
|
||||
warn!("Killing Dnscache refresh loop as its configuration is being reloaded");
|
||||
refresh_loop.abort()
|
||||
}
|
||||
let new_resolver = if let Some(ref data) = cached_resolver.data {
|
||||
let data = Some(data.read().unwrap().clone());
|
||||
CachedResolver::new(desired_config, data).await
|
||||
} else {
|
||||
CachedResolver::new(desired_config, None).await
|
||||
};
|
||||
|
||||
match new_resolver {
|
||||
Ok(ok) => {
|
||||
CACHED_RESOLVER.store(ok);
|
||||
Ok(())
|
||||
}
|
||||
Err(err) => {
|
||||
let message = format!("Error setting up cached_resolver. Error: {:?}, will continue without this feature.", err);
|
||||
Err(Error::DNSCachedError(message))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
// Stores the AddrSet in cache adquiring the write lock.
|
||||
fn store_in_cache(&self, host: &str, addr_set: AddrSet) {
|
||||
if let Some(ref data) = self.data {
|
||||
data.write().unwrap().insert(host.to_string(), addr_set);
|
||||
} else {
|
||||
error!("Could not insert, Hash not initialized");
|
||||
}
|
||||
}
|
||||
}
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use trust_dns_resolver::error::ResolveError;

    // Configuration shared by all tests: cache enabled, `dns_max_ttl` of 10.
    fn enabled_config() -> CachedResolverConfig {
        CachedResolverConfig::new(10, true)
    }

    // Creating a resolver with an enabled configuration succeeds.
    #[tokio::test]
    async fn new() {
        let created = CachedResolver::new(enabled_config(), None).await;
        assert!(created.is_ok());
    }

    // Looking up a host returns a successful response.
    #[tokio::test]
    async fn lookup_ip() {
        let resolver = CachedResolver::new(enabled_config(), None)
            .await
            .unwrap();
        let outcome = resolver.lookup_ip("www.google.com.").await;
        assert!(outcome.is_ok());
    }

    // A set that was just looked up is not reported as changed.
    #[tokio::test]
    async fn has_changed() {
        let resolver = CachedResolver::new(enabled_config(), None)
            .await
            .unwrap();
        let hostname = "www.google.com.";
        let addr_set = resolver.lookup_ip(hostname).await.unwrap();
        assert!(!resolver.has_changed(hostname, &addr_set));
    }

    // Looking up this hostname is expected to yield an error.
    #[tokio::test]
    async fn unknown_host() {
        let resolver = CachedResolver::new(enabled_config(), None)
            .await
            .unwrap();
        let outcome = resolver.lookup_ip("www.idontexists.").await;
        assert!(matches!(outcome, Err(ResolveError { .. })));
    }

    // A hostname containing a space yields an error, and `has_changed`
    // reports false for it afterwards.
    #[tokio::test]
    async fn incorrect_address() {
        let resolver = CachedResolver::new(enabled_config(), None)
            .await
            .unwrap();
        let hostname = "w ww.idontexists.";
        let outcome = resolver.lookup_ip(hostname).await;
        assert!(matches!(outcome, Err(ResolveError { .. })));
        assert!(!resolver.has_changed(hostname, &AddrSet::new()));
    }

    #[tokio::test]
    // Ok, this test is based on the fact that google does DNS RR
    // and does not responds with every available ip everytime, so
    // if I cache here, it will miss after one cache iteration or two.
    async fn thread() {
        let resolver = CachedResolver::new(enabled_config(), None)
            .await
            .unwrap();
        let hostname = "www.google.com.";
        let addr_set = resolver.lookup_ip(hostname).await.unwrap();
        assert!(!resolver.has_changed(hostname, &addr_set));

        // Run the refresher in the background on a clone of the resolver.
        let background = resolver.clone();
        let _thread_handle = tokio::task::spawn(async move {
            background.refresh_dns_entries_loop().await;
        });

        assert!(!resolver.has_changed(hostname, &addr_set));
    }
}
|
||||
Reference in New Issue
Block a user