mirror of
https://github.com/postgresml/pgcat.git
synced 2026-03-25 10:06:28 +00:00
Compare commits
15 Commits
circleci_A
...
circleci_O
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f2ed12e8ce | ||
|
|
a68071dd28 | ||
|
|
c27d801abf | ||
|
|
186e72298f | ||
|
|
3935366d86 | ||
|
|
b575935b1d | ||
|
|
efbab1c333 | ||
|
|
9f12d7958e | ||
|
|
e6634ef461 | ||
|
|
dab2e58647 | ||
|
|
4aaa4378cf | ||
|
|
670311daf9 | ||
|
|
b9ec7f8036 | ||
|
|
d91d23848b | ||
|
|
bbbc01a467 |
4
.github/workflows/chart-lint-test.yaml
vendored
4
.github/workflows/chart-lint-test.yaml
vendored
@@ -22,7 +22,7 @@ jobs:
|
||||
# Python is required because `ct lint` runs Yamale (https://github.com/23andMe/Yamale) and
|
||||
# yamllint (https://github.com/adrienverge/yamllint) which require Python
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4.1.0
|
||||
uses: actions/setup-python@v5.1.0
|
||||
with:
|
||||
python-version: 3.7
|
||||
|
||||
@@ -43,7 +43,7 @@ jobs:
|
||||
run: ct lint --config ct.yaml
|
||||
|
||||
- name: Create kind cluster
|
||||
uses: helm/kind-action@v1.7.0
|
||||
uses: helm/kind-action@v1.10.0
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
|
||||
- name: Run chart-testing (install)
|
||||
|
||||
2
.github/workflows/chart-release.yaml
vendored
2
.github/workflows/chart-release.yaml
vendored
@@ -32,7 +32,7 @@ jobs:
|
||||
version: v3.13.0
|
||||
|
||||
- name: Run chart-releaser
|
||||
uses: helm/chart-releaser-action@be16258da8010256c6e82849661221415f031968 # v1.5.0
|
||||
uses: helm/chart-releaser-action@a917fd15b20e8b64b94d9158ad54cd6345335584 # v1.6.0
|
||||
with:
|
||||
charts_dir: charts
|
||||
config: cr.yaml
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -12,3 +12,4 @@ dev/cache
|
||||
!dev/cache/.keepme
|
||||
.venv
|
||||
**/__pycache__
|
||||
.bundle
|
||||
36
CONFIG.md
36
CONFIG.md
@@ -36,10 +36,11 @@ Port at which prometheus exporter listens on.
|
||||
### connect_timeout
|
||||
```
|
||||
path: general.connect_timeout
|
||||
default: 5000 # milliseconds
|
||||
default: 1000 # milliseconds
|
||||
```
|
||||
|
||||
How long to wait before aborting a server connection (ms).
|
||||
How long the client waits to obtain a server connection before aborting (ms).
|
||||
This is similar to PgBouncer's `query_wait_timeout`.
|
||||
|
||||
### idle_timeout
|
||||
```
|
||||
@@ -129,6 +130,16 @@ default: 60 # seconds
|
||||
|
||||
How long to ban a server if it fails a health check (seconds).
|
||||
|
||||
### unban_replicas_when_all_banned
|
||||
```
|
||||
path: general.unban_replicas_when_all_banned
|
||||
default: true
|
||||
```
|
||||
|
||||
Whether or not we should unban all replicas when they are all banned. This is set
|
||||
to true by default to prevent disconnection when we have replicas with a false positive
|
||||
health check.
|
||||
|
||||
### log_client_connections
|
||||
```
|
||||
path: general.log_client_connections
|
||||
@@ -462,10 +473,18 @@ path: pools.<pool_name>.users.<user_index>.pool_size
|
||||
default: 9
|
||||
```
|
||||
|
||||
Maximum number of server connections that can be established for this user
|
||||
Maximum number of server connections that can be established for this user.
|
||||
The maximum number of connection from a single Pgcat process to any database in the cluster
|
||||
is the sum of pool_size across all users.
|
||||
|
||||
### min_pool_size
|
||||
```
|
||||
path: pools.<pool_name>.users.<user_index>.min_pool_size
|
||||
default: 0
|
||||
```
|
||||
|
||||
Minimum number of idle server connections to retain for this pool.
|
||||
|
||||
### statement_timeout
|
||||
```
|
||||
path: pools.<pool_name>.users.<user_index>.statement_timeout
|
||||
@@ -475,6 +494,16 @@ default: 0
|
||||
Maximum query duration. Dangerous, but protects against DBs that died in a non-obvious way.
|
||||
0 means it is disabled.
|
||||
|
||||
### connect_timeout
|
||||
```
|
||||
path: pools.<pool_name>.users.<user_index>.connect_timeout
|
||||
default: <UNSET> # milliseconds
|
||||
```
|
||||
|
||||
How long the client waits to obtain a server connection before aborting (ms).
|
||||
This is similar to PgBouncer's `query_wait_timeout`.
|
||||
If unset, uses the `connect_timeout` defined globally.
|
||||
|
||||
## `pools.<pool_name>.shards.<shard_index>` Section
|
||||
|
||||
### servers
|
||||
@@ -502,4 +531,3 @@ default: "shard0"
|
||||
```
|
||||
|
||||
Database name (e.g. "postgres")
|
||||
|
||||
|
||||
5
Cargo.lock
generated
5
Cargo.lock
generated
@@ -192,12 +192,11 @@ checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d"
|
||||
|
||||
[[package]]
|
||||
name = "bb8"
|
||||
version = "0.8.1"
|
||||
version = "0.8.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "98b4b0f25f18bcdc3ac72bdb486ed0acf7e185221fd4dc985bc15db5800b0ba2"
|
||||
checksum = "d89aabfae550a5c44b43ab941844ffcd2e993cb6900b342debf59e9ea74acdb8"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"futures-channel",
|
||||
"futures-util",
|
||||
"parking_lot",
|
||||
"tokio",
|
||||
|
||||
@@ -8,7 +8,7 @@ edition = "2021"
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
bytes = "1"
|
||||
md-5 = "0.10"
|
||||
bb8 = "0.8.1"
|
||||
bb8 = "=0.8.6"
|
||||
async-trait = "0.1"
|
||||
rand = "0.8"
|
||||
chrono = "0.4"
|
||||
|
||||
@@ -175,7 +175,7 @@ The setting will persist until it's changed again or the client disconnects.
|
||||
By default, all queries are routed to the first available server; `default_role` setting controls this behavior.
|
||||
|
||||
### Failover
|
||||
All servers are checked with a `;` (very fast) query before being given to a client. Additionally, the server health is monitored with every client query that it processes. If the server is not reachable, it will be banned and cannot serve any more transactions for the duration of the ban. The queries are routed to the remaining servers. If all servers become banned, the ban list is cleared: this is a safety precaution against false positives. The primary can never be banned.
|
||||
All servers are checked with a `;` (very fast) query before being given to a client. Additionally, the server health is monitored with every client query that it processes. If the server is not reachable, it will be banned and cannot serve any more transactions for the duration of the ban. The queries are routed to the remaining servers. If all servers become banned, the behavior is controlled by the configuration parameter `unban_replicas_when_all_banned`. If it is set to true (the default), the ban list is cleared: this is a safety precaution against false positives, if it is set to false, no replicas will be available until they become healthy. The primary can never be banned.
|
||||
|
||||
The ban time can be changed with `ban_time`. The default is 60 seconds.
|
||||
|
||||
|
||||
@@ -5,4 +5,4 @@ maintainers:
|
||||
- name: Wildcard
|
||||
email: support@w6d.io
|
||||
appVersion: "1.2.0"
|
||||
version: 0.2.0
|
||||
version: 0.2.1
|
||||
|
||||
@@ -15,6 +15,7 @@ stringData:
|
||||
connect_timeout = {{ .Values.configuration.general.connect_timeout }}
|
||||
idle_timeout = {{ .Values.configuration.general.idle_timeout | int }}
|
||||
server_lifetime = {{ .Values.configuration.general.server_lifetime | int }}
|
||||
server_tls = {{ .Values.configuration.general.server_tls }}
|
||||
idle_client_in_transaction_timeout = {{ .Values.configuration.general.idle_client_in_transaction_timeout | int }}
|
||||
healthcheck_timeout = {{ .Values.configuration.general.healthcheck_timeout }}
|
||||
healthcheck_delay = {{ .Values.configuration.general.healthcheck_delay }}
|
||||
@@ -58,11 +59,21 @@ stringData:
|
||||
##
|
||||
[pools.{{ $pool.name | quote }}.users.{{ $index }}]
|
||||
username = {{ $user.username | quote }}
|
||||
{{- if $user.password }}
|
||||
password = {{ $user.password | quote }}
|
||||
{{- else if and $user.passwordSecret.name $user.passwordSecret.key }}
|
||||
{{- $secret := (lookup "v1" "Secret" $.Release.Namespace $user.passwordSecret.name) }}
|
||||
{{- if $secret }}
|
||||
{{- $password := index $secret.data $user.passwordSecret.key | b64dec }}
|
||||
password = {{ $password | quote }}
|
||||
{{- end }}
|
||||
{{- end }}
|
||||
pool_size = {{ $user.pool_size }}
|
||||
statement_timeout = {{ $user.statement_timeout }}
|
||||
min_pool_size = 3
|
||||
server_lifetime = 60000
|
||||
statement_timeout = {{ default 0 $user.statement_timeout }}
|
||||
min_pool_size = {{ default 3 $user.min_pool_size }}
|
||||
{{- if $user.server_lifetime }}
|
||||
server_lifetime = {{ $user.server_lifetime }}
|
||||
{{- end }}
|
||||
{{- if and $user.server_username $user.server_password }}
|
||||
server_username = {{ $user.server_username | quote }}
|
||||
server_password = {{ $user.server_password | quote }}
|
||||
|
||||
@@ -175,6 +175,9 @@ configuration:
|
||||
# Max connection lifetime before it's closed, even if actively used.
|
||||
server_lifetime: 86400000 # 24 hours
|
||||
|
||||
# Whether to use TLS for server connections or not.
|
||||
server_tls: false
|
||||
|
||||
# How long a client is allowed to be idle while in a transaction (ms).
|
||||
idle_client_in_transaction_timeout: 0 # milliseconds
|
||||
|
||||
@@ -315,7 +318,9 @@ configuration:
|
||||
# ## Credentials for users that may connect to this cluster
|
||||
# ## @param users [array]
|
||||
# ## @param users[0].username Name of the env var (required)
|
||||
# ## @param users[0].password Value for the env var (required)
|
||||
# ## @param users[0].password Value for the env var (required) leave empty to use existing secret see passwordSecret.name and passwordSecret.key
|
||||
# ## @param users[0].passwordSecret.name Name of the secret containing the password
|
||||
# ## @param users[0].passwordSecret.key Key in the secret containing the password
|
||||
# ## @param users[0].pool_size Maximum number of server connections that can be established for this user
|
||||
# ## @param users[0].statement_timeout Maximum query duration. Dangerous, but protects against DBs that died in a non-obvious way.
|
||||
# users: []
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use crate::config::AuthType;
|
||||
use crate::errors::Error;
|
||||
use crate::pool::ConnectionPool;
|
||||
use crate::server::Server;
|
||||
@@ -71,6 +72,7 @@ impl AuthPassthrough {
|
||||
pub async fn fetch_hash(&self, address: &crate::config::Address) -> Result<String, Error> {
|
||||
let auth_user = crate::config::User {
|
||||
username: self.user.clone(),
|
||||
auth_type: AuthType::MD5,
|
||||
password: Some(self.password.clone()),
|
||||
server_username: None,
|
||||
server_password: None,
|
||||
|
||||
317
src/client.rs
317
src/client.rs
@@ -14,7 +14,9 @@ use tokio::sync::mpsc::Sender;
|
||||
|
||||
use crate::admin::{generate_server_parameters_for_admin, handle_admin};
|
||||
use crate::auth_passthrough::refetch_auth_hash;
|
||||
use crate::config::{get_config, get_idle_client_in_transaction_timeout, Address, PoolMode};
|
||||
use crate::config::{
|
||||
get_config, get_idle_client_in_transaction_timeout, Address, AuthType, PoolMode,
|
||||
};
|
||||
use crate::constants::*;
|
||||
use crate::messages::*;
|
||||
use crate::plugins::PluginOutput;
|
||||
@@ -463,8 +465,8 @@ where
|
||||
.count()
|
||||
== 1;
|
||||
|
||||
// Kick any client that's not admin while we're in admin-only mode.
|
||||
if !admin && admin_only {
|
||||
// Kick any client that's not admin while we're in admin-only mode.
|
||||
debug!(
|
||||
"Rejecting non-admin connection to {} when in admin only mode",
|
||||
pool_name
|
||||
@@ -481,72 +483,76 @@ where
|
||||
let process_id: i32 = rand::random();
|
||||
let secret_key: i32 = rand::random();
|
||||
|
||||
// Perform MD5 authentication.
|
||||
// TODO: Add SASL support.
|
||||
let salt = md5_challenge(&mut write).await?;
|
||||
|
||||
let code = match read.read_u8().await {
|
||||
Ok(p) => p,
|
||||
Err(_) => {
|
||||
return Err(Error::ClientSocketError(
|
||||
"password code".into(),
|
||||
client_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
// PasswordMessage
|
||||
if code as char != 'p' {
|
||||
return Err(Error::ProtocolSyncError(format!(
|
||||
"Expected p, got {}",
|
||||
code as char
|
||||
)));
|
||||
}
|
||||
|
||||
let len = match read.read_i32().await {
|
||||
Ok(len) => len,
|
||||
Err(_) => {
|
||||
return Err(Error::ClientSocketError(
|
||||
"password message length".into(),
|
||||
client_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let mut password_response = vec![0u8; (len - 4) as usize];
|
||||
|
||||
match read.read_exact(&mut password_response).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => {
|
||||
return Err(Error::ClientSocketError(
|
||||
"password message".into(),
|
||||
client_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let mut prepared_statements_enabled = false;
|
||||
|
||||
// Authenticate admin user.
|
||||
let (transaction_mode, mut server_parameters) = if admin {
|
||||
let config = get_config();
|
||||
// TODO: Add SASL support.
|
||||
// Perform MD5 authentication.
|
||||
match config.general.admin_auth_type {
|
||||
AuthType::Trust => (),
|
||||
AuthType::MD5 => {
|
||||
let salt = md5_challenge(&mut write).await?;
|
||||
|
||||
// Compare server and client hashes.
|
||||
let password_hash = md5_hash_password(
|
||||
&config.general.admin_username,
|
||||
&config.general.admin_password,
|
||||
&salt,
|
||||
);
|
||||
let code = match read.read_u8().await {
|
||||
Ok(p) => p,
|
||||
Err(_) => {
|
||||
return Err(Error::ClientSocketError(
|
||||
"password code".into(),
|
||||
client_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
if password_hash != password_response {
|
||||
let error = Error::ClientGeneralError("Invalid password".into(), client_identifier);
|
||||
// PasswordMessage
|
||||
if code as char != 'p' {
|
||||
return Err(Error::ProtocolSyncError(format!(
|
||||
"Expected p, got {}",
|
||||
code as char
|
||||
)));
|
||||
}
|
||||
|
||||
warn!("{}", error);
|
||||
wrong_password(&mut write, username).await?;
|
||||
let len = match read.read_i32().await {
|
||||
Ok(len) => len,
|
||||
Err(_) => {
|
||||
return Err(Error::ClientSocketError(
|
||||
"password message length".into(),
|
||||
client_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
return Err(error);
|
||||
let mut password_response = vec![0u8; (len - 4) as usize];
|
||||
|
||||
match read.read_exact(&mut password_response).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => {
|
||||
return Err(Error::ClientSocketError(
|
||||
"password message".into(),
|
||||
client_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
// Compare server and client hashes.
|
||||
let password_hash = md5_hash_password(
|
||||
&config.general.admin_username,
|
||||
&config.general.admin_password,
|
||||
&salt,
|
||||
);
|
||||
|
||||
if password_hash != password_response {
|
||||
let error =
|
||||
Error::ClientGeneralError("Invalid password".into(), client_identifier);
|
||||
|
||||
warn!("{}", error);
|
||||
wrong_password(&mut write, username).await?;
|
||||
|
||||
return Err(error);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
(false, generate_server_parameters_for_admin())
|
||||
}
|
||||
// Authenticate normal user.
|
||||
@@ -573,92 +579,143 @@ where
|
||||
// Obtain the hash to compare, we give preference to that written in cleartext in config
|
||||
// if there is nothing set in cleartext and auth passthrough (auth_query) is configured, we use the hash obtained
|
||||
// when the pool was created. If there is no hash there, we try to fetch it one more time.
|
||||
let password_hash = if let Some(password) = &pool.settings.user.password {
|
||||
Some(md5_hash_password(username, password, &salt))
|
||||
} else {
|
||||
if !get_config().is_auth_query_configured() {
|
||||
wrong_password(&mut write, username).await?;
|
||||
return Err(Error::ClientAuthImpossible(username.into()));
|
||||
}
|
||||
match pool.settings.user.auth_type {
|
||||
AuthType::Trust => (),
|
||||
AuthType::MD5 => {
|
||||
// Perform MD5 authentication.
|
||||
// TODO: Add SASL support.
|
||||
let salt = md5_challenge(&mut write).await?;
|
||||
|
||||
let mut hash = (*pool.auth_hash.read()).clone();
|
||||
let code = match read.read_u8().await {
|
||||
Ok(p) => p,
|
||||
Err(_) => {
|
||||
return Err(Error::ClientSocketError(
|
||||
"password code".into(),
|
||||
client_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
if hash.is_none() {
|
||||
warn!(
|
||||
"Query auth configured \
|
||||
but no hash password found \
|
||||
for pool {}. Will try to refetch it.",
|
||||
pool_name
|
||||
);
|
||||
// PasswordMessage
|
||||
if code as char != 'p' {
|
||||
return Err(Error::ProtocolSyncError(format!(
|
||||
"Expected p, got {}",
|
||||
code as char
|
||||
)));
|
||||
}
|
||||
|
||||
match refetch_auth_hash(&pool).await {
|
||||
Ok(fetched_hash) => {
|
||||
warn!("Password for {}, obtained. Updating.", client_identifier);
|
||||
let len = match read.read_i32().await {
|
||||
Ok(len) => len,
|
||||
Err(_) => {
|
||||
return Err(Error::ClientSocketError(
|
||||
"password message length".into(),
|
||||
client_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let mut password_response = vec![0u8; (len - 4) as usize];
|
||||
|
||||
match read.read_exact(&mut password_response).await {
|
||||
Ok(_) => (),
|
||||
Err(_) => {
|
||||
return Err(Error::ClientSocketError(
|
||||
"password message".into(),
|
||||
client_identifier,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let password_hash = if let Some(password) = &pool.settings.user.password {
|
||||
Some(md5_hash_password(username, password, &salt))
|
||||
} else {
|
||||
if !get_config().is_auth_query_configured() {
|
||||
wrong_password(&mut write, username).await?;
|
||||
return Err(Error::ClientAuthImpossible(username.into()));
|
||||
}
|
||||
|
||||
let mut hash = (*pool.auth_hash.read()).clone();
|
||||
|
||||
if hash.is_none() {
|
||||
warn!(
|
||||
"Query auth configured \
|
||||
but no hash password found \
|
||||
for pool {}. Will try to refetch it.",
|
||||
pool_name
|
||||
);
|
||||
|
||||
match refetch_auth_hash(&pool).await {
|
||||
Ok(fetched_hash) => {
|
||||
warn!(
|
||||
"Password for {}, obtained. Updating.",
|
||||
client_identifier
|
||||
);
|
||||
|
||||
{
|
||||
let mut pool_auth_hash = pool.auth_hash.write();
|
||||
*pool_auth_hash = Some(fetched_hash.clone());
|
||||
}
|
||||
|
||||
hash = Some(fetched_hash);
|
||||
}
|
||||
|
||||
Err(err) => {
|
||||
wrong_password(&mut write, username).await?;
|
||||
|
||||
return Err(Error::ClientAuthPassthroughError(
|
||||
err.to_string(),
|
||||
client_identifier,
|
||||
));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Some(md5_hash_second_pass(&hash.unwrap(), &salt))
|
||||
};
|
||||
|
||||
// Once we have the resulting hash, we compare with what the client gave us.
|
||||
// If they do not match and auth query is set up, we try to refetch the hash one more time
|
||||
// to see if the password has changed since the pool was created.
|
||||
//
|
||||
// @TODO: we could end up fetching again the same password twice (see above).
|
||||
if password_hash.unwrap() != password_response {
|
||||
warn!(
|
||||
"Invalid password {}, will try to refetch it.",
|
||||
client_identifier
|
||||
);
|
||||
|
||||
let fetched_hash = match refetch_auth_hash(&pool).await {
|
||||
Ok(fetched_hash) => fetched_hash,
|
||||
Err(err) => {
|
||||
wrong_password(&mut write, username).await?;
|
||||
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
|
||||
let new_password_hash = md5_hash_second_pass(&fetched_hash, &salt);
|
||||
|
||||
// Ok password changed in server an auth is possible.
|
||||
if new_password_hash == password_response {
|
||||
warn!(
|
||||
"Password for {}, changed in server. Updating.",
|
||||
client_identifier
|
||||
);
|
||||
|
||||
{
|
||||
let mut pool_auth_hash = pool.auth_hash.write();
|
||||
*pool_auth_hash = Some(fetched_hash.clone());
|
||||
*pool_auth_hash = Some(fetched_hash);
|
||||
}
|
||||
|
||||
hash = Some(fetched_hash);
|
||||
}
|
||||
|
||||
Err(err) => {
|
||||
} else {
|
||||
wrong_password(&mut write, username).await?;
|
||||
|
||||
return Err(Error::ClientAuthPassthroughError(
|
||||
err.to_string(),
|
||||
return Err(Error::ClientGeneralError(
|
||||
"Invalid password".into(),
|
||||
client_identifier,
|
||||
));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Some(md5_hash_second_pass(&hash.unwrap(), &salt))
|
||||
};
|
||||
|
||||
// Once we have the resulting hash, we compare with what the client gave us.
|
||||
// If they do not match and auth query is set up, we try to refetch the hash one more time
|
||||
// to see if the password has changed since the pool was created.
|
||||
//
|
||||
// @TODO: we could end up fetching again the same password twice (see above).
|
||||
if password_hash.unwrap() != password_response {
|
||||
warn!(
|
||||
"Invalid password {}, will try to refetch it.",
|
||||
client_identifier
|
||||
);
|
||||
|
||||
let fetched_hash = match refetch_auth_hash(&pool).await {
|
||||
Ok(fetched_hash) => fetched_hash,
|
||||
Err(err) => {
|
||||
wrong_password(&mut write, username).await?;
|
||||
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
|
||||
let new_password_hash = md5_hash_second_pass(&fetched_hash, &salt);
|
||||
|
||||
// Ok password changed in server an auth is possible.
|
||||
if new_password_hash == password_response {
|
||||
warn!(
|
||||
"Password for {}, changed in server. Updating.",
|
||||
client_identifier
|
||||
);
|
||||
|
||||
{
|
||||
let mut pool_auth_hash = pool.auth_hash.write();
|
||||
*pool_auth_hash = Some(fetched_hash);
|
||||
}
|
||||
} else {
|
||||
wrong_password(&mut write, username).await?;
|
||||
return Err(Error::ClientGeneralError(
|
||||
"Invalid password".into(),
|
||||
client_identifier,
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let transaction_mode = pool.settings.pool_mode == PoolMode::Transaction;
|
||||
prepared_statements_enabled =
|
||||
transaction_mode && pool.prepared_statement_cache.is_some();
|
||||
|
||||
@@ -208,6 +208,9 @@ impl Address {
|
||||
pub struct User {
|
||||
pub username: String,
|
||||
pub password: Option<String>,
|
||||
|
||||
#[serde(default = "User::default_auth_type")]
|
||||
pub auth_type: AuthType,
|
||||
pub server_username: Option<String>,
|
||||
pub server_password: Option<String>,
|
||||
pub pool_size: u32,
|
||||
@@ -225,6 +228,7 @@ impl Default for User {
|
||||
User {
|
||||
username: String::from("postgres"),
|
||||
password: None,
|
||||
auth_type: AuthType::MD5,
|
||||
server_username: None,
|
||||
server_password: None,
|
||||
pool_size: 15,
|
||||
@@ -239,6 +243,10 @@ impl Default for User {
|
||||
}
|
||||
|
||||
impl User {
|
||||
pub fn default_auth_type() -> AuthType {
|
||||
AuthType::MD5
|
||||
}
|
||||
|
||||
fn validate(&self) -> Result<(), Error> {
|
||||
if let Some(min_pool_size) = self.min_pool_size {
|
||||
if min_pool_size > self.pool_size {
|
||||
@@ -307,6 +315,9 @@ pub struct General {
|
||||
#[serde(default = "General::default_ban_time")]
|
||||
pub ban_time: i64,
|
||||
|
||||
#[serde(default)] // True
|
||||
pub unban_replicas_when_all_banned: bool,
|
||||
|
||||
#[serde(default = "General::default_idle_client_in_transaction_timeout")]
|
||||
pub idle_client_in_transaction_timeout: u64,
|
||||
|
||||
@@ -334,6 +345,9 @@ pub struct General {
|
||||
pub admin_username: String,
|
||||
pub admin_password: String,
|
||||
|
||||
#[serde(default = "General::default_admin_auth_type")]
|
||||
pub admin_auth_type: AuthType,
|
||||
|
||||
#[serde(default = "General::default_validate_config")]
|
||||
pub validate_config: bool,
|
||||
|
||||
@@ -348,6 +362,10 @@ impl General {
|
||||
"0.0.0.0".into()
|
||||
}
|
||||
|
||||
pub fn default_admin_auth_type() -> AuthType {
|
||||
AuthType::MD5
|
||||
}
|
||||
|
||||
pub fn default_port() -> u16 {
|
||||
5432
|
||||
}
|
||||
@@ -445,6 +463,7 @@ impl Default for General {
|
||||
healthcheck_timeout: Self::default_healthcheck_timeout(),
|
||||
healthcheck_delay: Self::default_healthcheck_delay(),
|
||||
ban_time: Self::default_ban_time(),
|
||||
unban_replicas_when_all_banned: true,
|
||||
idle_client_in_transaction_timeout: Self::default_idle_client_in_transaction_timeout(),
|
||||
server_lifetime: Self::default_server_lifetime(),
|
||||
server_round_robin: Self::default_server_round_robin(),
|
||||
@@ -456,6 +475,7 @@ impl Default for General {
|
||||
verify_server_certificate: false,
|
||||
admin_username: String::from("admin"),
|
||||
admin_password: String::from("admin"),
|
||||
admin_auth_type: AuthType::MD5,
|
||||
validate_config: true,
|
||||
auth_query: None,
|
||||
auth_query_user: None,
|
||||
@@ -476,6 +496,15 @@ pub enum PoolMode {
|
||||
Session,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Copy, Hash)]
|
||||
pub enum AuthType {
|
||||
#[serde(alias = "trust", alias = "Trust")]
|
||||
Trust,
|
||||
|
||||
#[serde(alias = "md5", alias = "MD5")]
|
||||
MD5,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for PoolMode {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
|
||||
10
src/pool.rs
10
src/pool.rs
@@ -189,6 +189,9 @@ pub struct PoolSettings {
|
||||
// Ban time
|
||||
pub ban_time: i64,
|
||||
|
||||
// Should we automatically unban replicas when all are banned?
|
||||
pub unban_replicas_when_all_banned: bool,
|
||||
|
||||
// Regex for searching for the sharding key in SQL statements
|
||||
pub sharding_key_regex: Option<Regex>,
|
||||
|
||||
@@ -228,6 +231,7 @@ impl Default for PoolSettings {
|
||||
healthcheck_delay: General::default_healthcheck_delay(),
|
||||
healthcheck_timeout: General::default_healthcheck_timeout(),
|
||||
ban_time: General::default_ban_time(),
|
||||
unban_replicas_when_all_banned: true,
|
||||
sharding_key_regex: None,
|
||||
shard_id_regex: None,
|
||||
regex_search_limit: 1000,
|
||||
@@ -541,6 +545,9 @@ impl ConnectionPool {
|
||||
healthcheck_delay: config.general.healthcheck_delay,
|
||||
healthcheck_timeout: config.general.healthcheck_timeout,
|
||||
ban_time: config.general.ban_time,
|
||||
unban_replicas_when_all_banned: config
|
||||
.general
|
||||
.unban_replicas_when_all_banned,
|
||||
sharding_key_regex: pool_config
|
||||
.sharding_key_regex
|
||||
.clone()
|
||||
@@ -946,8 +953,9 @@ impl ConnectionPool {
|
||||
let read_guard = self.banlist.read();
|
||||
let all_replicas_banned = read_guard[address.shard].len() == replicas_available;
|
||||
drop(read_guard);
|
||||
let unban_replicas_when_all_banned = self.settings.clone().unban_replicas_when_all_banned;
|
||||
|
||||
if all_replicas_banned {
|
||||
if all_replicas_banned && unban_replicas_when_all_banned {
|
||||
let mut write_guard = self.banlist.write();
|
||||
warn!("Unbanning all replicas.");
|
||||
write_guard[address.shard].clear();
|
||||
|
||||
@@ -309,6 +309,7 @@ async fn prometheus_stats(
|
||||
push_pool_stats(&mut lines);
|
||||
push_server_stats(&mut lines);
|
||||
push_database_stats(&mut lines);
|
||||
lines.push("".to_string()); // Ensure to end the stats with a line terminator as required by the specification.
|
||||
|
||||
Response::builder()
|
||||
.header("content-type", "text/plain; version=0.0.4")
|
||||
|
||||
@@ -386,6 +386,18 @@ impl QueryRouter {
|
||||
}
|
||||
}
|
||||
|
||||
/// Determines if a query is a mutation or not.
|
||||
fn is_mutation_query(q: &sqlparser::ast::Query) -> bool {
|
||||
use sqlparser::ast::*;
|
||||
|
||||
match q.body.as_ref() {
|
||||
SetExpr::Insert(_) => true,
|
||||
SetExpr::Update(_) => true,
|
||||
SetExpr::Query(q) => Self::is_mutation_query(q),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Try to infer which server to connect to based on the contents of the query.
|
||||
pub fn infer(&mut self, ast: &Vec<sqlparser::ast::Statement>) -> Result<(), Error> {
|
||||
if !self.pool_settings.query_parser_read_write_splitting {
|
||||
@@ -428,8 +440,9 @@ impl QueryRouter {
|
||||
};
|
||||
|
||||
let has_locks = !query.locks.is_empty();
|
||||
let has_mutation = Self::is_mutation_query(query);
|
||||
|
||||
if has_locks {
|
||||
if has_locks || has_mutation {
|
||||
self.active_role = Some(Role::Primary);
|
||||
} else if !visited_write_statement {
|
||||
// If we already visited a write statement, we should be going to the primary.
|
||||
@@ -1113,6 +1126,26 @@ mod test {
|
||||
assert_eq!(qr.role(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_split_cte_queries() {
|
||||
QueryRouter::setup();
|
||||
let mut qr = QueryRouter::new();
|
||||
qr.pool_settings.query_parser_read_write_splitting = true;
|
||||
qr.pool_settings.query_parser_enabled = true;
|
||||
|
||||
let query = simple_query(
|
||||
"WITH t AS (
|
||||
SELECT id FROM users WHERE name ILIKE '%ja%'
|
||||
)
|
||||
UPDATE user_languages
|
||||
SET settings = '{}'
|
||||
FROM t WHERE t.id = user_id;",
|
||||
);
|
||||
let ast = qr.parse(&query).unwrap();
|
||||
assert!(qr.infer(&ast).is_ok());
|
||||
assert_eq!(qr.role(), Some(Role::Primary));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_infer_replica() {
|
||||
QueryRouter::setup();
|
||||
@@ -1431,6 +1464,7 @@ mod test {
|
||||
healthcheck_delay: PoolSettings::default().healthcheck_delay,
|
||||
healthcheck_timeout: PoolSettings::default().healthcheck_timeout,
|
||||
ban_time: PoolSettings::default().ban_time,
|
||||
unban_replicas_when_all_banned: true,
|
||||
sharding_key_regex: None,
|
||||
shard_id_regex: None,
|
||||
default_shard: crate::config::DefaultShard::Shard(0),
|
||||
@@ -1509,6 +1543,7 @@ mod test {
|
||||
healthcheck_delay: PoolSettings::default().healthcheck_delay,
|
||||
healthcheck_timeout: PoolSettings::default().healthcheck_timeout,
|
||||
ban_time: PoolSettings::default().ban_time,
|
||||
unban_replicas_when_all_banned: true,
|
||||
sharding_key_regex: Some(Regex::new(r"/\* sharding_key: (\d+) \*/").unwrap()),
|
||||
shard_id_regex: Some(Regex::new(r"/\* shard_id: (\d+) \*/").unwrap()),
|
||||
default_shard: crate::config::DefaultShard::Shard(0),
|
||||
|
||||
71
tests/python/test_auth.py
Normal file
71
tests/python/test_auth.py
Normal file
@@ -0,0 +1,71 @@
|
||||
import utils
|
||||
import signal
|
||||
|
||||
class TestTrustAuth:
|
||||
@classmethod
|
||||
def setup_method(cls):
|
||||
config= """
|
||||
[general]
|
||||
host = "0.0.0.0"
|
||||
port = 6432
|
||||
admin_username = "admin_user"
|
||||
admin_password = ""
|
||||
admin_auth_type = "trust"
|
||||
|
||||
[pools.sharded_db.users.0]
|
||||
username = "sharding_user"
|
||||
password = "sharding_user"
|
||||
auth_type = "trust"
|
||||
pool_size = 10
|
||||
min_pool_size = 1
|
||||
pool_mode = "transaction"
|
||||
|
||||
[pools.sharded_db.shards.0]
|
||||
servers = [
|
||||
[ "127.0.0.1", 5432, "primary" ],
|
||||
]
|
||||
database = "shard0"
|
||||
"""
|
||||
utils.pgcat_generic_start(config)
|
||||
|
||||
@classmethod
|
||||
def teardown_method(self):
|
||||
utils.pg_cat_send_signal(signal.SIGTERM)
|
||||
|
||||
def test_admin_trust_auth(self):
|
||||
conn, cur = utils.connect_db_trust(admin=True)
|
||||
cur.execute("SHOW POOLS")
|
||||
res = cur.fetchall()
|
||||
print(res)
|
||||
utils.cleanup_conn(conn, cur)
|
||||
|
||||
def test_normal_trust_auth(self):
|
||||
conn, cur = utils.connect_db_trust(autocommit=False)
|
||||
cur.execute("SELECT 1")
|
||||
res = cur.fetchall()
|
||||
print(res)
|
||||
utils.cleanup_conn(conn, cur)
|
||||
|
||||
class TestMD5Auth:
|
||||
@classmethod
|
||||
def setup_method(cls):
|
||||
utils.pgcat_start()
|
||||
|
||||
@classmethod
|
||||
def teardown_method(self):
|
||||
utils.pg_cat_send_signal(signal.SIGTERM)
|
||||
|
||||
def test_normal_db_access(self):
|
||||
conn, cur = utils.connect_db(autocommit=False)
|
||||
cur.execute("SELECT 1")
|
||||
res = cur.fetchall()
|
||||
print(res)
|
||||
utils.cleanup_conn(conn, cur)
|
||||
|
||||
def test_admin_db_access(self):
|
||||
conn, cur = utils.connect_db(admin=True)
|
||||
|
||||
cur.execute("SHOW POOLS")
|
||||
res = cur.fetchall()
|
||||
print(res)
|
||||
utils.cleanup_conn(conn, cur)
|
||||
@@ -1,30 +1,12 @@
|
||||
import os
|
||||
|
||||
import signal
|
||||
import time
|
||||
|
||||
import psycopg2
|
||||
|
||||
import utils
|
||||
|
||||
SHUTDOWN_TIMEOUT = 5
|
||||
|
||||
def test_normal_db_access():
|
||||
utils.pgcat_start()
|
||||
conn, cur = utils.connect_db(autocommit=False)
|
||||
cur.execute("SELECT 1")
|
||||
res = cur.fetchall()
|
||||
print(res)
|
||||
utils.cleanup_conn(conn, cur)
|
||||
|
||||
|
||||
def test_admin_db_access():
|
||||
conn, cur = utils.connect_db(admin=True)
|
||||
|
||||
cur.execute("SHOW POOLS")
|
||||
res = cur.fetchall()
|
||||
print(res)
|
||||
utils.cleanup_conn(conn, cur)
|
||||
|
||||
|
||||
def test_shutdown_logic():
|
||||
|
||||
@@ -256,3 +238,5 @@ def test_shutdown_logic():
|
||||
|
||||
utils.cleanup_conn(conn, cur)
|
||||
utils.pg_cat_send_signal(signal.SIGTERM)
|
||||
|
||||
# - - - - - - - - - - - - - - - - - -
|
||||
|
||||
@@ -1,20 +1,49 @@
|
||||
from typing import Tuple
|
||||
import os
|
||||
import psutil
|
||||
import signal
|
||||
import time
|
||||
from typing import Tuple
|
||||
import tempfile
|
||||
|
||||
import psutil
|
||||
import psycopg2
|
||||
|
||||
PGCAT_HOST = "127.0.0.1"
|
||||
PGCAT_PORT = "6432"
|
||||
|
||||
def pgcat_start():
|
||||
|
||||
def _pgcat_start(config_path: str):
|
||||
pg_cat_send_signal(signal.SIGTERM)
|
||||
os.system("./target/debug/pgcat .circleci/pgcat.toml &")
|
||||
os.system(f"./target/debug/pgcat {config_path} &")
|
||||
time.sleep(2)
|
||||
|
||||
|
||||
def pgcat_start():
|
||||
_pgcat_start(config_path='.circleci/pgcat.toml')
|
||||
|
||||
|
||||
def pgcat_generic_start(config: str):
|
||||
tmp = tempfile.NamedTemporaryFile()
|
||||
with open(tmp.name, 'w') as f:
|
||||
f.write(config)
|
||||
_pgcat_start(config_path=tmp.name)
|
||||
|
||||
|
||||
def glauth_send_signal(signal: signal.Signals):
|
||||
try:
|
||||
for proc in psutil.process_iter(["pid", "name"]):
|
||||
if proc.name() == "glauth":
|
||||
os.kill(proc.pid, signal)
|
||||
except Exception as e:
|
||||
# The process can be gone when we send this signal
|
||||
print(e)
|
||||
|
||||
if signal == signal.SIGTERM:
|
||||
# Returns 0 if pgcat process exists
|
||||
time.sleep(2)
|
||||
if not os.system('pgrep glauth'):
|
||||
raise Exception("glauth not closed after SIGTERM")
|
||||
|
||||
|
||||
def pg_cat_send_signal(signal: signal.Signals):
|
||||
try:
|
||||
for proc in psutil.process_iter(["pid", "name"]):
|
||||
@@ -54,6 +83,27 @@ def connect_db(
|
||||
|
||||
return (conn, cur)
|
||||
|
||||
def connect_db_trust(
|
||||
autocommit: bool = True,
|
||||
admin: bool = False,
|
||||
) -> Tuple[psycopg2.extensions.connection, psycopg2.extensions.cursor]:
|
||||
|
||||
if admin:
|
||||
user = "admin_user"
|
||||
db = "pgcat"
|
||||
else:
|
||||
user = "sharding_user"
|
||||
db = "sharded_db"
|
||||
|
||||
conn = psycopg2.connect(
|
||||
f"postgres://{user}@{PGCAT_HOST}:{PGCAT_PORT}/{db}?application_name=testing_pgcat",
|
||||
connect_timeout=2,
|
||||
)
|
||||
conn.autocommit = autocommit
|
||||
cur = conn.cursor()
|
||||
|
||||
return (conn, cur)
|
||||
|
||||
|
||||
def cleanup_conn(conn: psycopg2.extensions.connection, cur: psycopg2.extensions.cursor):
|
||||
cur.close()
|
||||
|
||||
@@ -1,22 +1,33 @@
|
||||
GEM
|
||||
remote: https://rubygems.org/
|
||||
specs:
|
||||
activemodel (7.0.4.1)
|
||||
activesupport (= 7.0.4.1)
|
||||
activerecord (7.0.4.1)
|
||||
activemodel (= 7.0.4.1)
|
||||
activesupport (= 7.0.4.1)
|
||||
activesupport (7.0.4.1)
|
||||
activemodel (7.1.4)
|
||||
activesupport (= 7.1.4)
|
||||
activerecord (7.1.4)
|
||||
activemodel (= 7.1.4)
|
||||
activesupport (= 7.1.4)
|
||||
timeout (>= 0.4.0)
|
||||
activesupport (7.1.4)
|
||||
base64
|
||||
bigdecimal
|
||||
concurrent-ruby (~> 1.0, >= 1.0.2)
|
||||
connection_pool (>= 2.2.5)
|
||||
drb
|
||||
i18n (>= 1.6, < 2)
|
||||
minitest (>= 5.1)
|
||||
mutex_m
|
||||
tzinfo (~> 2.0)
|
||||
ast (2.4.2)
|
||||
concurrent-ruby (1.1.10)
|
||||
base64 (0.2.0)
|
||||
bigdecimal (3.1.8)
|
||||
concurrent-ruby (1.3.4)
|
||||
connection_pool (2.4.1)
|
||||
diff-lcs (1.5.0)
|
||||
i18n (1.12.0)
|
||||
drb (2.2.1)
|
||||
i18n (1.14.5)
|
||||
concurrent-ruby (~> 1.0)
|
||||
minitest (5.17.0)
|
||||
minitest (5.25.1)
|
||||
mutex_m (0.2.0)
|
||||
parallel (1.22.1)
|
||||
parser (3.1.2.0)
|
||||
ast (~> 2.4.1)
|
||||
@@ -24,7 +35,8 @@ GEM
|
||||
pg (1.3.2)
|
||||
rainbow (3.1.1)
|
||||
regexp_parser (2.3.1)
|
||||
rexml (3.2.5)
|
||||
rexml (3.3.6)
|
||||
strscan
|
||||
rspec (3.11.0)
|
||||
rspec-core (~> 3.11.0)
|
||||
rspec-expectations (~> 3.11.0)
|
||||
@@ -50,10 +62,12 @@ GEM
|
||||
rubocop-ast (1.17.0)
|
||||
parser (>= 3.1.1.0)
|
||||
ruby-progressbar (1.11.0)
|
||||
strscan (3.1.0)
|
||||
timeout (0.4.1)
|
||||
toml (0.3.0)
|
||||
parslet (>= 1.8.0, < 3.0.0)
|
||||
toxiproxy (2.0.1)
|
||||
tzinfo (2.0.5)
|
||||
tzinfo (2.0.6)
|
||||
concurrent-ruby (~> 1.0)
|
||||
unicode-display_width (2.1.0)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user