Add dns_cache for server addresses as in pgbouncer (#249)

* Add dns_cache so server addresses are cached and invalidated when DNS changes.

Adds a module to deal with dns_cache feature. It's
main struct is CachedResolver, which is a simple thread safe
hostname <-> Ips cache with the ability to refresh resolutions
every `dns_max_ttl` seconds. This way, a client can check whether its
ip address has changed.

* Allow reloading dns cached

* Add documentation for dns_cached
This commit is contained in:
Jose Fernández
2023-05-02 10:26:40 +02:00
committed by GitHub
parent 3601130ba1
commit 7dfbd993f2
10 changed files with 794 additions and 3 deletions

View File

@@ -188,6 +188,22 @@ default: "admin_pass"
Password to access the virtual administrative database Password to access the virtual administrative database
### dns_cache_enabled
```
path: general.dns_cache_enabled
default: false
```
When enabled, ip resolutions for server connections specified using hostnames will be cached
and checked for changes every `dns_max_ttl` seconds. If a change in the host resolution is found
old ip connections are closed (gracefully) and new connections will start using new ip.
### dns_max_ttl
```
path: general.dns_max_ttl
default: 30
```
Specifies how often (in seconds) cached ip addresses for servers are rechecked (see `dns_cache_enabled`).
## `pools.<pool_name>` Section ## `pools.<pool_name>` Section
### pool_mode ### pool_mode

288
Cargo.lock generated
View File

@@ -26,6 +26,27 @@ version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bddcadddf5e9015d310179a59bb28c4d4b9920ad0f11e8e14dbadf654890c9a6" checksum = "bddcadddf5e9015d310179a59bb28c4d4b9920ad0f11e8e14dbadf654890c9a6"
[[package]]
name = "async-stream"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dad5c83079eae9969be7fadefe640a1c566901f05ff91ab221de4b6f68d9507e"
dependencies = [
"async-stream-impl",
"futures-core",
]
[[package]]
name = "async-stream-impl"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "10f203db73a71dfa2fb6dd22763990fa26f3d2625a6da2da900d23b87d26be27"
dependencies = [
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]] [[package]]
name = "async-trait" name = "async-trait"
version = "0.1.68" version = "0.1.68"
@@ -212,6 +233,12 @@ dependencies = [
"syn 1.0.109", "syn 1.0.109",
] ]
[[package]]
name = "data-encoding"
version = "2.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3ee2393c4a91429dffb4bedf19f4d6abf27d8a732c8ce4980305d782e5426d57"
[[package]] [[package]]
name = "digest" name = "digest"
version = "0.10.6" version = "0.10.6"
@@ -223,6 +250,18 @@ dependencies = [
"subtle", "subtle",
] ]
[[package]]
name = "enum-as-inner"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c9720bba047d567ffc8a3cba48bf19126600e249ab7f128e9233e6376976a116"
dependencies = [
"heck",
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]] [[package]]
name = "env_logger" name = "env_logger"
version = "0.10.0" version = "0.10.0"
@@ -275,6 +314,15 @@ version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
[[package]]
name = "form_urlencoded"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a9c384f161156f5260c24a097c56119f9be8c798586aecc13afbcbe7b7e26bf8"
dependencies = [
"percent-encoding",
]
[[package]] [[package]]
name = "futures" name = "futures"
version = "0.3.28" version = "0.3.28"
@@ -410,6 +458,12 @@ version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
[[package]]
name = "heck"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2540771e65fc8cb83cd6e8a237f70c319bd5c29f78ed1084ba5d50eeac86f7f9"
[[package]] [[package]]
name = "hermit-abi" name = "hermit-abi"
version = "0.2.6" version = "0.2.6"
@@ -434,6 +488,17 @@ dependencies = [
"digest", "digest",
] ]
[[package]]
name = "hostname"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c731c3e10504cc8ed35cfe2f1db4c9274c3d35fa486e3b31df46f068ef3e867"
dependencies = [
"libc",
"match_cfg",
"winapi",
]
[[package]] [[package]]
name = "http" name = "http"
version = "0.2.9" version = "0.2.9"
@@ -522,6 +587,27 @@ dependencies = [
"cxx-build", "cxx-build",
] ]
[[package]]
name = "idna"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "418a0a6fab821475f634efe3ccc45c013f742efe03d853e8d3355d5cb850ecf8"
dependencies = [
"matches",
"unicode-bidi",
"unicode-normalization",
]
[[package]]
name = "idna"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e14ddfc70884202db2244c223200c204c2bda1bc6e0998d11b5e024d657209e6"
dependencies = [
"unicode-bidi",
"unicode-normalization",
]
[[package]] [[package]]
name = "indexmap" name = "indexmap"
version = "1.9.2" version = "1.9.2"
@@ -542,6 +628,24 @@ dependencies = [
"windows-sys", "windows-sys",
] ]
[[package]]
name = "ipconfig"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bd302af1b90f2463a98fa5ad469fc212c8e3175a41c3068601bfa2727591c5be"
dependencies = [
"socket2",
"widestring",
"winapi",
"winreg",
]
[[package]]
name = "ipnet"
version = "2.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f88c5561171189e69df9d98bcf18fd5f9558300f7ea7b801eb8a0fd748bd8745"
[[package]] [[package]]
name = "is-terminal" name = "is-terminal"
version = "0.4.4" version = "0.4.4"
@@ -589,6 +693,12 @@ dependencies = [
"wasm-bindgen", "wasm-bindgen",
] ]
[[package]]
name = "lazy_static"
version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
[[package]] [[package]]
name = "libc" name = "libc"
version = "0.2.139" version = "0.2.139"
@@ -604,6 +714,12 @@ dependencies = [
"cc", "cc",
] ]
[[package]]
name = "linked-hash-map"
version = "0.5.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f"
[[package]] [[package]]
name = "linux-raw-sys" name = "linux-raw-sys"
version = "0.1.4" version = "0.1.4"
@@ -629,6 +745,27 @@ dependencies = [
"cfg-if", "cfg-if",
] ]
[[package]]
name = "lru-cache"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "31e24f1ad8321ca0e8a1e0ac13f23cb668e6f5466c2c57319f6a5cf1cc8e3b1c"
dependencies = [
"linked-hash-map",
]
[[package]]
name = "match_cfg"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ffbee8634e0d45d258acb448e7eaab3fce7a0a467395d4d9f228e3c1f01fb2e4"
[[package]]
name = "matches"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a3e378b66a060d48947b590737b30a1be76706c8dd7b8ba0f2fe3989c68a853f"
[[package]] [[package]]
name = "md-5" name = "md-5"
version = "0.10.5" version = "0.10.5"
@@ -737,6 +874,12 @@ dependencies = [
"windows-sys", "windows-sys",
] ]
[[package]]
name = "percent-encoding"
version = "2.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "478c572c3d73181ff3c2539045f6eb99e5491218eae919370993b890cdbdd98e"
[[package]] [[package]]
name = "pgcat" name = "pgcat"
version = "1.0.1" version = "1.0.1"
@@ -777,7 +920,9 @@ dependencies = [
"stringprep", "stringprep",
"tokio", "tokio",
"tokio-rustls", "tokio-rustls",
"tokio-test",
"toml", "toml",
"trust-dns-resolver",
"webpki-roots", "webpki-roots",
] ]
@@ -888,6 +1033,12 @@ dependencies = [
"unicode-ident", "unicode-ident",
] ]
[[package]]
name = "quick-error"
version = "1.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0"
[[package]] [[package]]
name = "quote" name = "quote"
version = "1.0.26" version = "1.0.26"
@@ -953,6 +1104,16 @@ version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6868896879ba532248f33598de5181522d8b3d9d724dfd230911e1a7d4822f5" checksum = "b6868896879ba532248f33598de5181522d8b3d9d724dfd230911e1a7d4822f5"
[[package]]
name = "resolv-conf"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "52e44394d2086d010551b14b53b1f24e31647570cd1deb0379e2c21b329aba00"
dependencies = [
"hostname",
"quick-error",
]
[[package]] [[package]]
name = "ring" name = "ring"
version = "0.16.20" version = "0.16.20"
@@ -1191,6 +1352,26 @@ dependencies = [
"winapi-util", "winapi-util",
] ]
[[package]]
name = "thiserror"
version = "1.0.37"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "10deb33631e3c9018b9baf9dcbbc4f737320d2b576bac10f6aefa048fa407e3e"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "1.0.37"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "982d17546b47146b28f7c22e3d08465f6b8903d0ea13c1660d9d84a6e7adcdbb"
dependencies = [
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]] [[package]]
name = "time" name = "time"
version = "0.1.45" version = "0.1.45"
@@ -1258,6 +1439,30 @@ dependencies = [
"tokio", "tokio",
] ]
[[package]]
name = "tokio-stream"
version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d660770404473ccd7bc9f8b28494a811bc18542b915c0855c51e8f419d5223ce"
dependencies = [
"futures-core",
"pin-project-lite",
"tokio",
]
[[package]]
name = "tokio-test"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "53474327ae5e166530d17f2d956afcb4f8a004de581b3cae10f12006bc8163e3"
dependencies = [
"async-stream",
"bytes",
"futures-core",
"tokio",
"tokio-stream",
]
[[package]] [[package]]
name = "tokio-util" name = "tokio-util"
version = "0.7.7" version = "0.7.7"
@@ -1320,9 +1525,21 @@ checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"pin-project-lite", "pin-project-lite",
"tracing-attributes",
"tracing-core", "tracing-core",
] ]
[[package]]
name = "tracing-attributes"
version = "0.1.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4017f8f45139870ca7e672686113917c71c7a6e02d4924eda67186083c03081a"
dependencies = [
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]] [[package]]
name = "tracing-core" name = "tracing-core"
version = "0.1.30" version = "0.1.30"
@@ -1332,6 +1549,51 @@ dependencies = [
"once_cell", "once_cell",
] ]
[[package]]
name = "trust-dns-proto"
version = "0.22.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4f7f83d1e4a0e4358ac54c5c3681e5d7da5efc5a7a632c90bb6d6669ddd9bc26"
dependencies = [
"async-trait",
"cfg-if",
"data-encoding",
"enum-as-inner",
"futures-channel",
"futures-io",
"futures-util",
"idna 0.2.3",
"ipnet",
"lazy_static",
"rand",
"smallvec",
"thiserror",
"tinyvec",
"tokio",
"tracing",
"url",
]
[[package]]
name = "trust-dns-resolver"
version = "0.22.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aff21aa4dcefb0a1afbfac26deb0adc93888c7d295fb63ab273ef276ba2b7cfe"
dependencies = [
"cfg-if",
"futures-util",
"ipconfig",
"lazy_static",
"lru-cache",
"parking_lot",
"resolv-conf",
"smallvec",
"thiserror",
"tokio",
"tracing",
"trust-dns-proto",
]
[[package]] [[package]]
name = "try-lock" name = "try-lock"
version = "0.2.4" version = "0.2.4"
@@ -1377,6 +1639,17 @@ version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a"
[[package]]
name = "url"
version = "2.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0d68c799ae75762b8c3fe375feb6600ef5602c883c5d21eb51c09f22b83c4643"
dependencies = [
"form_urlencoded",
"idna 0.3.0",
"percent-encoding",
]
[[package]] [[package]]
name = "version_check" name = "version_check"
version = "0.9.4" version = "0.9.4"
@@ -1478,6 +1751,12 @@ dependencies = [
"rustls-webpki", "rustls-webpki",
] ]
[[package]]
name = "widestring"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "17882f045410753661207383517a6f62ec3dbeb6a4ed2acce01f0728238d1983"
[[package]] [[package]]
name = "winapi" name = "winapi"
version = "0.3.9" version = "0.3.9"
@@ -1583,3 +1862,12 @@ checksum = "faf09497b8f8b5ac5d3bb4d05c0a99be20f26fd3d5f2db7b0716e946d5103658"
dependencies = [ dependencies = [
"memchr", "memchr",
] ]
[[package]]
name = "winreg"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "80d0f4e272c85def139476380b12f9ac60926689dd2e01d4923222f40580869d"
dependencies = [
"winapi",
]

View File

@@ -42,6 +42,8 @@ fallible-iterator = "0.2"
pin-project = "1" pin-project = "1"
webpki-roots = "0.23" webpki-roots = "0.23"
rustls = { version = "0.21", features = ["dangerous_configuration"] } rustls = { version = "0.21", features = ["dangerous_configuration"] }
trust-dns-resolver = "0.22.0"
tokio-test = "0.4.2"
[target.'cfg(not(target_env = "msvc"))'.dependencies] [target.'cfg(not(target_env = "msvc"))'.dependencies]
jemallocator = "0.5.0" jemallocator = "0.5.0"

View File

@@ -146,6 +146,14 @@ idle_timeout = 40000
# Connect timeout can be overwritten in the pool # Connect timeout can be overwritten in the pool
connect_timeout = 3000 connect_timeout = 3000
# When enabled, ip resolutions for server connections specified using hostnames will be cached
# and checked for changes every `dns_max_ttl` seconds. If a change in the host resolution is found
# old ip connections are closed (gracefully) and new connections will start using new ip.
# dns_cache_enabled = false
# Specifies how often (in seconds) cached ip addresses for servers are rechecked (see `dns_cache_enabled`).
# dns_max_ttl = 30
# User configs are structured as pool.<pool_name>.users.<user_index> # User configs are structured as pool.<pool_name>.users.<user_index>
# This section holds the credentials for users that may connect to this cluster # This section holds the credentials for users that may connect to this cluster
[pools.sharded_db.users.0] [pools.sharded_db.users.0]

View File

@@ -12,6 +12,7 @@ use std::sync::Arc;
use tokio::fs::File; use tokio::fs::File;
use tokio::io::AsyncReadExt; use tokio::io::AsyncReadExt;
use crate::dns_cache::CachedResolver;
use crate::errors::Error; use crate::errors::Error;
use crate::pool::{ClientServerMap, ConnectionPool}; use crate::pool::{ClientServerMap, ConnectionPool};
use crate::sharding::ShardingFunction; use crate::sharding::ShardingFunction;
@@ -255,6 +256,12 @@ pub struct General {
#[serde(default)] // False #[serde(default)] // False
pub log_client_disconnections: bool, pub log_client_disconnections: bool,
#[serde(default)] // False
pub dns_cache_enabled: bool,
#[serde(default = "General::default_dns_max_ttl")]
pub dns_max_ttl: u64,
#[serde(default = "General::default_shutdown_timeout")] #[serde(default = "General::default_shutdown_timeout")]
pub shutdown_timeout: u64, pub shutdown_timeout: u64,
@@ -336,6 +343,10 @@ impl General {
60000 60000
} }
pub fn default_dns_max_ttl() -> u64 {
30
}
pub fn default_healthcheck_timeout() -> u64 { pub fn default_healthcheck_timeout() -> u64 {
1000 1000
} }
@@ -378,6 +389,8 @@ impl Default for General {
log_client_connections: false, log_client_connections: false,
log_client_disconnections: false, log_client_disconnections: false,
autoreload: None, autoreload: None,
dns_cache_enabled: false,
dns_max_ttl: Self::default_dns_max_ttl(),
tls_certificate: None, tls_certificate: None,
tls_private_key: None, tls_private_key: None,
server_tls: false, server_tls: false,
@@ -1119,6 +1132,10 @@ pub async fn reload_config(client_server_map: ClientServerMap) -> Result<bool, E
} }
}; };
let new_config = get_config(); let new_config = get_config();
match CachedResolver::from_config().await {
Ok(_) => (),
Err(err) => error!("DNS cache reinitialization error: {:?}", err),
};
if old_config.pools != new_config.pools { if old_config.pools != new_config.pools {
info!("Pool configuration changed"); info!("Pool configuration changed");

410
src/dns_cache.rs Normal file
View File

@@ -0,0 +1,410 @@
use crate::config::get_config;
use crate::errors::Error;
use arc_swap::ArcSwap;
use log::{debug, error, info, warn};
use once_cell::sync::Lazy;
use std::collections::{HashMap, HashSet};
use std::io;
use std::net::IpAddr;
use std::sync::Arc;
use std::sync::RwLock;
use tokio::time::{sleep, Duration};
use trust_dns_resolver::error::{ResolveError, ResolveResult};
use trust_dns_resolver::lookup_ip::LookupIp;
use trust_dns_resolver::TokioAsyncResolver;
/// Cached Resolver Globally available
pub static CACHED_RESOLVER: Lazy<ArcSwap<CachedResolver>> =
Lazy::new(|| ArcSwap::from_pointee(CachedResolver::default()));
// Ip addressed are returned as a set of addresses
// so we can compare.
#[derive(Clone, PartialEq, Debug)]
pub struct AddrSet {
set: HashSet<IpAddr>,
}
impl AddrSet {
fn new() -> AddrSet {
AddrSet {
set: HashSet::new(),
}
}
}
impl From<LookupIp> for AddrSet {
fn from(lookup_ip: LookupIp) -> Self {
let mut addr_set = AddrSet::new();
for address in lookup_ip.iter() {
addr_set.set.insert(address);
}
addr_set
}
}
///
/// A CachedResolver is a DNS resolution cache mechanism with customizable expiration time.
///
/// The system works as follows:
///
/// When a host is to be resolved, if we have not resolved it before, a new resolution is
/// executed and stored in the internal cache. Concurrently, every `dns_max_ttl` time, the
/// cache is refreshed.
///
/// # Example:
///
/// ```
/// use pgcat::dns_cache::{CachedResolverConfig, CachedResolver};
///
/// # tokio_test::block_on(async {
/// let config = CachedResolverConfig::default();
/// let resolver = CachedResolver::new(config, None).await.unwrap();
/// let addrset = resolver.lookup_ip("www.example.com.").await.unwrap();
/// # })
/// ```
///
/// // Now the ip resolution is stored in local cache and subsequent
/// // calls will be returned from cache. Also, the cache is refreshed
/// // and updated every 10 seconds.
///
/// // You can now check if an 'old' lookup differs from what it's currently
/// // store in cache by using `has_changed`.
/// resolver.has_changed("www.example.com.", addrset)
#[derive(Default)]
pub struct CachedResolver {
// The configuration of the cached_resolver.
config: CachedResolverConfig,
// This is the hash that contains the hash.
data: Option<RwLock<HashMap<String, AddrSet>>>,
// The resolver to be used for DNS queries.
resolver: Option<TokioAsyncResolver>,
// The RefreshLoop
refresh_loop: RwLock<Option<tokio::task::JoinHandle<()>>>,
}
///
/// Configuration
#[derive(Clone, Debug, Default, PartialEq)]
pub struct CachedResolverConfig {
/// Amount of time in secods that a resolved dns address is considered stale.
dns_max_ttl: u64,
/// Enabled or disabled? (this is so we can reload config)
enabled: bool,
}
impl CachedResolverConfig {
fn new(dns_max_ttl: u64, enabled: bool) -> Self {
CachedResolverConfig {
dns_max_ttl,
enabled,
}
}
}
impl From<crate::config::Config> for CachedResolverConfig {
fn from(config: crate::config::Config) -> Self {
CachedResolverConfig::new(config.general.dns_max_ttl, config.general.dns_cache_enabled)
}
}
impl CachedResolver {
///
/// Returns a new Arc<CachedResolver> based on passed configuration.
/// It also starts the loop that will refresh cache entries.
///
/// # Arguments:
///
/// * `config` - The `CachedResolverConfig` to be used to create the resolver.
///
/// # Example:
///
/// ```
/// use pgcat::dns_cache::{CachedResolverConfig, CachedResolver};
///
/// # tokio_test::block_on(async {
/// let config = CachedResolverConfig::default();
/// let resolver = CachedResolver::new(config, None).await.unwrap();
/// # })
/// ```
///
pub async fn new(
config: CachedResolverConfig,
data: Option<HashMap<String, AddrSet>>,
) -> Result<Arc<Self>, io::Error> {
// Construct a new Resolver with default configuration options
let resolver = Some(TokioAsyncResolver::tokio_from_system_conf()?);
let data = if let Some(hash) = data {
Some(RwLock::new(hash))
} else {
Some(RwLock::new(HashMap::new()))
};
let instance = Arc::new(Self {
config,
resolver,
data,
refresh_loop: RwLock::new(None),
});
if instance.enabled() {
info!("Scheduling DNS refresh loop");
let refresh_loop = tokio::task::spawn({
let instance = instance.clone();
async move {
instance.refresh_dns_entries_loop().await;
}
});
*(instance.refresh_loop.write().unwrap()) = Some(refresh_loop);
}
Ok(instance)
}
pub fn enabled(&self) -> bool {
self.config.enabled
}
// Schedules the refresher
async fn refresh_dns_entries_loop(&self) {
let resolver = TokioAsyncResolver::tokio_from_system_conf().unwrap();
let interval = Duration::from_secs(self.config.dns_max_ttl);
loop {
debug!("Begin refreshing cached DNS addresses.");
// To minimize the time we hold the lock, we first create
// an array with keys.
let mut hostnames: Vec<String> = Vec::new();
{
if let Some(ref data) = self.data {
for hostname in data.read().unwrap().keys() {
hostnames.push(hostname.clone());
}
}
}
for hostname in hostnames.iter() {
let addrset = self
.fetch_from_cache(hostname.as_str())
.expect("Could not obtain expected address from cache, this should not happen");
match resolver.lookup_ip(hostname).await {
Ok(lookup_ip) => {
let new_addrset = AddrSet::from(lookup_ip);
debug!(
"Obtained address for host ({}) -> ({:?})",
hostname, new_addrset
);
if addrset != new_addrset {
debug!(
"Addr changed from {:?} to {:?} updating cache.",
addrset, new_addrset
);
self.store_in_cache(hostname, new_addrset);
}
}
Err(err) => {
error!(
"There was an error trying to resolv {}: ({}).",
hostname, err
);
}
}
}
debug!("Finished refreshing cached DNS addresses.");
sleep(interval).await;
}
}
/// Returns a `AddrSet` given the specified hostname.
///
/// This method first tries to fetch the value from the cache, if it misses
/// then it is resolved and stored in the cache. TTL from records is ignored.
///
/// # Arguments
///
/// * `host` - A string slice referencing the hostname to be resolved.
///
/// # Example:
///
/// ```
/// use pgcat::dns_cache::{CachedResolverConfig, CachedResolver};
///
/// # tokio_test::block_on(async {
/// let config = CachedResolverConfig::default();
/// let resolver = CachedResolver::new(config, None).await.unwrap();
/// let response = resolver.lookup_ip("www.google.com.");
/// # })
/// ```
///
pub async fn lookup_ip(&self, host: &str) -> ResolveResult<AddrSet> {
debug!("Lookup up {} in cache", host);
match self.fetch_from_cache(host) {
Some(addr_set) => {
debug!("Cache hit!");
Ok(addr_set)
}
None => {
debug!("Not found, executing a dns query!");
if let Some(ref resolver) = self.resolver {
let addr_set = AddrSet::from(resolver.lookup_ip(host).await?);
debug!("Obtained: {:?}", addr_set);
self.store_in_cache(host, addr_set.clone());
Ok(addr_set)
} else {
Err(ResolveError::from("No resolver available"))
}
}
}
}
//
// Returns true if the stored host resolution differs from the AddrSet passed.
pub fn has_changed(&self, host: &str, addr_set: &AddrSet) -> bool {
if let Some(fetched_addr_set) = self.fetch_from_cache(host) {
return fetched_addr_set != *addr_set;
}
false
}
// Fetches an AddrSet from the inner cache adquiring the read lock.
fn fetch_from_cache(&self, key: &str) -> Option<AddrSet> {
if let Some(ref hash) = self.data {
if let Some(addr_set) = hash.read().unwrap().get(key) {
return Some(addr_set.clone());
}
}
None
}
// Sets up the global CACHED_RESOLVER static variable so we can globally use DNS
// cache.
pub async fn from_config() -> Result<(), Error> {
let cached_resolver = CACHED_RESOLVER.load();
let desired_config = CachedResolverConfig::from(get_config());
if cached_resolver.config != desired_config {
if let Some(ref refresh_loop) = *(cached_resolver.refresh_loop.write().unwrap()) {
warn!("Killing Dnscache refresh loop as its configuration is being reloaded");
refresh_loop.abort()
}
let new_resolver = if let Some(ref data) = cached_resolver.data {
let data = Some(data.read().unwrap().clone());
CachedResolver::new(desired_config, data).await
} else {
CachedResolver::new(desired_config, None).await
};
match new_resolver {
Ok(ok) => {
CACHED_RESOLVER.store(ok);
Ok(())
}
Err(err) => {
let message = format!("Error setting up cached_resolver. Error: {:?}, will continue without this feature.", err);
Err(Error::DNSCachedError(message))
}
}
} else {
Ok(())
}
}
// Stores the AddrSet in cache adquiring the write lock.
fn store_in_cache(&self, host: &str, addr_set: AddrSet) {
if let Some(ref data) = self.data {
data.write().unwrap().insert(host.to_string(), addr_set);
} else {
error!("Could not insert, Hash not initialized");
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use trust_dns_resolver::error::ResolveError;
#[tokio::test]
async fn new() {
let config = CachedResolverConfig {
dns_max_ttl: 10,
enabled: true,
};
let resolver = CachedResolver::new(config, None).await;
assert!(resolver.is_ok());
}
#[tokio::test]
async fn lookup_ip() {
let config = CachedResolverConfig {
dns_max_ttl: 10,
enabled: true,
};
let resolver = CachedResolver::new(config, None).await.unwrap();
let response = resolver.lookup_ip("www.google.com.").await;
assert!(response.is_ok());
}
#[tokio::test]
async fn has_changed() {
let config = CachedResolverConfig {
dns_max_ttl: 10,
enabled: true,
};
let resolver = CachedResolver::new(config, None).await.unwrap();
let hostname = "www.google.com.";
let response = resolver.lookup_ip(hostname).await;
let addr_set = response.unwrap();
assert!(!resolver.has_changed(hostname, &addr_set));
}
#[tokio::test]
async fn unknown_host() {
let config = CachedResolverConfig {
dns_max_ttl: 10,
enabled: true,
};
let resolver = CachedResolver::new(config, None).await.unwrap();
let hostname = "www.idontexists.";
let response = resolver.lookup_ip(hostname).await;
assert!(matches!(response, Err(ResolveError { .. })));
}
#[tokio::test]
async fn incorrect_address() {
let config = CachedResolverConfig {
dns_max_ttl: 10,
enabled: true,
};
let resolver = CachedResolver::new(config, None).await.unwrap();
let hostname = "w ww.idontexists.";
let response = resolver.lookup_ip(hostname).await;
assert!(matches!(response, Err(ResolveError { .. })));
assert!(!resolver.has_changed(hostname, &AddrSet::new()));
}
#[tokio::test]
// Ok, this test is based on the fact that google does DNS RR
// and does not responds with every available ip everytime, so
// if I cache here, it will miss after one cache iteration or two.
async fn thread() {
let config = CachedResolverConfig {
dns_max_ttl: 10,
enabled: true,
};
let resolver = CachedResolver::new(config, None).await.unwrap();
let hostname = "www.google.com.";
let response = resolver.lookup_ip(hostname).await;
let addr_set = response.unwrap();
assert!(!resolver.has_changed(hostname, &addr_set));
let resolver_for_refresher = resolver.clone();
let _thread_handle = tokio::task::spawn(async move {
resolver_for_refresher.refresh_dns_entries_loop().await;
});
assert!(!resolver.has_changed(hostname, &addr_set));
}
}

View File

@@ -19,6 +19,7 @@ pub enum Error {
ClientError(String), ClientError(String),
TlsError, TlsError,
StatementTimeout, StatementTimeout,
DNSCachedError(String),
ShuttingDown, ShuttingDown,
ParseBytesError(String), ParseBytesError(String),
AuthError(String), AuthError(String),

View File

@@ -1,6 +1,7 @@
pub mod auth_passthrough; pub mod auth_passthrough;
pub mod config; pub mod config;
pub mod constants; pub mod constants;
pub mod dns_cache;
pub mod errors; pub mod errors;
pub mod messages; pub mod messages;
pub mod mirrors; pub mod mirrors;

View File

@@ -36,6 +36,7 @@ extern crate sqlparser;
extern crate tokio; extern crate tokio;
extern crate tokio_rustls; extern crate tokio_rustls;
extern crate toml; extern crate toml;
extern crate trust_dns_resolver;
#[cfg(not(target_env = "msvc"))] #[cfg(not(target_env = "msvc"))]
use jemallocator::Jemalloc; use jemallocator::Jemalloc;
@@ -65,6 +66,7 @@ mod auth_passthrough;
mod client; mod client;
mod config; mod config;
mod constants; mod constants;
mod dns_cache;
mod errors; mod errors;
mod messages; mod messages;
mod mirrors; mod mirrors;
@@ -166,8 +168,14 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
// Statistics reporting. // Statistics reporting.
REPORTER.store(Arc::new(Reporter::default())); REPORTER.store(Arc::new(Reporter::default()));
// Connection pool that allows to query all shards and replicas. // Starts (if enabled) dns cache before pools initialization
match ConnectionPool::from_config(client_server_map.clone()).await { match dns_cache::CachedResolver::from_config().await {
Ok(_) => (),
Err(err) => error!("DNS cache initialization error: {:?}", err),
};
// Connection pool that allows to query all shards and replicas.
match ConnectionPool::from_config(client_server_map.clone()).await {
Ok(_) => (), Ok(_) => (),
Err(err) => { Err(err) => {
error!("Pool error: {:?}", err); error!("Pool error: {:?}", err);

View File

@@ -7,6 +7,7 @@ use parking_lot::{Mutex, RwLock};
use postgres_protocol::message; use postgres_protocol::message;
use std::collections::HashMap; use std::collections::HashMap;
use std::io::Read; use std::io::Read;
use std::net::IpAddr;
use std::sync::Arc; use std::sync::Arc;
use std::time::SystemTime; use std::time::SystemTime;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, BufStream}; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, BufStream};
@@ -16,6 +17,7 @@ use tokio_rustls::{client::TlsStream, TlsConnector};
use crate::config::{get_config, Address, User}; use crate::config::{get_config, Address, User};
use crate::constants::*; use crate::constants::*;
use crate::dns_cache::{AddrSet, CACHED_RESOLVER};
use crate::errors::{Error, ServerIdentifier}; use crate::errors::{Error, ServerIdentifier};
use crate::messages::*; use crate::messages::*;
use crate::mirrors::MirroringManager; use crate::mirrors::MirroringManager;
@@ -148,6 +150,9 @@ pub struct Server {
last_activity: SystemTime, last_activity: SystemTime,
mirror_manager: Option<MirroringManager>, mirror_manager: Option<MirroringManager>,
// Associated addresses used
addr_set: Option<AddrSet>,
} }
impl Server { impl Server {
@@ -161,6 +166,24 @@ impl Server {
stats: Arc<ServerStats>, stats: Arc<ServerStats>,
auth_hash: Arc<RwLock<Option<String>>>, auth_hash: Arc<RwLock<Option<String>>>,
) -> Result<Server, Error> { ) -> Result<Server, Error> {
let cached_resolver = CACHED_RESOLVER.load();
let mut addr_set: Option<AddrSet> = None;
// If we are caching addresses and hostname is not an IP
if cached_resolver.enabled() && address.host.parse::<IpAddr>().is_err() {
debug!("Resolving {}", &address.host);
addr_set = match cached_resolver.lookup_ip(&address.host).await {
Ok(ok) => {
debug!("Obtained: {:?}", ok);
Some(ok)
}
Err(err) => {
warn!("Error trying to resolve {}, ({:?})", &address.host, err);
None
}
}
};
let mut stream = let mut stream =
match TcpStream::connect(&format!("{}:{}", &address.host, address.port)).await { match TcpStream::connect(&format!("{}:{}", &address.host, address.port)).await {
Ok(stream) => stream, Ok(stream) => stream,
@@ -609,6 +632,7 @@ impl Server {
bad: false, bad: false,
needs_cleanup: false, needs_cleanup: false,
client_server_map, client_server_map,
addr_set,
connected_at: chrono::offset::Utc::now().naive_utc(), connected_at: chrono::offset::Utc::now().naive_utc(),
stats, stats,
application_name: String::new(), application_name: String::new(),
@@ -849,7 +873,23 @@ impl Server {
/// Server & client are out of sync, we must discard this connection. /// Server & client are out of sync, we must discard this connection.
/// This happens with clients that misbehave. /// This happens with clients that misbehave.
pub fn is_bad(&self) -> bool { pub fn is_bad(&self) -> bool {
self.bad if self.bad {
return self.bad;
};
let cached_resolver = CACHED_RESOLVER.load();
if cached_resolver.enabled() {
if let Some(addr_set) = &self.addr_set {
if cached_resolver.has_changed(self.address.host.as_str(), addr_set) {
warn!(
"DNS changed for {}, it was {:?}. Dropping server connection.",
self.address.host.as_str(),
addr_set
);
return true;
}
}
}
false
} }
/// Get server startup information to forward it to the client. /// Get server startup information to forward it to the client.