Compare commits

..

3 Commits

Author SHA1 Message Date
dependabot[bot]
d6e11e11da chore(deps): bump sqlparser from 0.41.0 to 0.51.0
Bumps [sqlparser](https://github.com/sqlparser-rs/sqlparser-rs) from 0.41.0 to 0.51.0.
- [Changelog](https://github.com/sqlparser-rs/sqlparser-rs/blob/main/CHANGELOG.md)
- [Commits](https://github.com/sqlparser-rs/sqlparser-rs/compare/v0.41.0...v0.51.0)

---
updated-dependencies:
- dependency-name: sqlparser
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-09-13 19:30:23 -05:00
dependabot[bot]
4aaa4378cf chore(deps): bump rexml from 3.2.8 to 3.3.6 in /tests/ruby (#803)
Bumps [rexml](https://github.com/ruby/rexml) from 3.2.8 to 3.3.6.
- [Release notes](https://github.com/ruby/rexml/releases)
- [Changelog](https://github.com/ruby/rexml/blob/master/NEWS.md)
- [Commits](https://github.com/ruby/rexml/compare/v3.2.8...v3.3.6)

---
updated-dependencies:
- dependency-name: rexml
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2024-09-13 19:19:30 -05:00
Andrew Jackson
670311daf9 Implement Trust Authentication (#805)
* Implement Trust Authentication

* Remove remaining LDAP stuff

* Reverted LDAP changes, Cleaned up tests

---------

Co-authored-by: Andrew Jackson <andrewjackson2988@gmail.com>
Co-authored-by: CommanderKeynes <andrewjackson947@gmail.coma>
2024-09-10 09:29:45 -05:00
11 changed files with 348 additions and 159 deletions

View File

@@ -8,7 +8,7 @@ jobs:
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v2
- name: Login to GitHub Container Registry - name: Login to GitHub Container Registry
uses: docker/login-action@v3 uses: docker/login-action@v1
with: with:
registry: ghcr.io registry: ghcr.io
username: ${{ github.actor }} username: ${{ github.actor }}

4
Cargo.lock generated
View File

@@ -1526,9 +1526,9 @@ checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d"
[[package]] [[package]]
name = "sqlparser" name = "sqlparser"
version = "0.41.0" version = "0.51.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5cc2c25a6c66789625ef164b4c7d2e548d627902280c13710d33da8222169964" checksum = "5fe11944a61da0da3f592e19a45ebe5ab92dc14a779907ff1f08fbb797bfefc7"
dependencies = [ dependencies = [
"log", "log",
"sqlparser_derive", "sqlparser_derive",

View File

@@ -19,7 +19,7 @@ serde_derive = "1"
regex = "1" regex = "1"
num_cpus = "1" num_cpus = "1"
once_cell = "1" once_cell = "1"
sqlparser = { version = "0.41", features = ["visitor"] } sqlparser = { version = "0.51", features = ["visitor"] }
log = "0.4" log = "0.4"
arc-swap = "1" arc-swap = "1"
parking_lot = "0.12.1" parking_lot = "0.12.1"

View File

@@ -1,3 +1,4 @@
use crate::config::AuthType;
use crate::errors::Error; use crate::errors::Error;
use crate::pool::ConnectionPool; use crate::pool::ConnectionPool;
use crate::server::Server; use crate::server::Server;
@@ -71,6 +72,7 @@ impl AuthPassthrough {
pub async fn fetch_hash(&self, address: &crate::config::Address) -> Result<String, Error> { pub async fn fetch_hash(&self, address: &crate::config::Address) -> Result<String, Error> {
let auth_user = crate::config::User { let auth_user = crate::config::User {
username: self.user.clone(), username: self.user.clone(),
auth_type: AuthType::MD5,
password: Some(self.password.clone()), password: Some(self.password.clone()),
server_username: None, server_username: None,
server_password: None, server_password: None,

View File

@@ -14,7 +14,9 @@ use tokio::sync::mpsc::Sender;
use crate::admin::{generate_server_parameters_for_admin, handle_admin}; use crate::admin::{generate_server_parameters_for_admin, handle_admin};
use crate::auth_passthrough::refetch_auth_hash; use crate::auth_passthrough::refetch_auth_hash;
use crate::config::{get_config, get_idle_client_in_transaction_timeout, Address, PoolMode}; use crate::config::{
get_config, get_idle_client_in_transaction_timeout, Address, AuthType, PoolMode,
};
use crate::constants::*; use crate::constants::*;
use crate::messages::*; use crate::messages::*;
use crate::plugins::PluginOutput; use crate::plugins::PluginOutput;
@@ -463,8 +465,8 @@ where
.count() .count()
== 1; == 1;
// Kick any client that's not admin while we're in admin-only mode.
if !admin && admin_only { if !admin && admin_only {
// Kick any client that's not admin while we're in admin-only mode.
debug!( debug!(
"Rejecting non-admin connection to {} when in admin only mode", "Rejecting non-admin connection to {} when in admin only mode",
pool_name pool_name
@@ -481,6 +483,105 @@ where
let process_id: i32 = rand::random(); let process_id: i32 = rand::random();
let secret_key: i32 = rand::random(); let secret_key: i32 = rand::random();
let mut prepared_statements_enabled = false;
// Authenticate admin user.
let (transaction_mode, mut server_parameters) = if admin {
let config = get_config();
// TODO: Add SASL support.
// Perform MD5 authentication.
match config.general.admin_auth_type {
AuthType::Trust => (),
AuthType::MD5 => {
let salt = md5_challenge(&mut write).await?;
let code = match read.read_u8().await {
Ok(p) => p,
Err(_) => {
return Err(Error::ClientSocketError(
"password code".into(),
client_identifier,
))
}
};
// PasswordMessage
if code as char != 'p' {
return Err(Error::ProtocolSyncError(format!(
"Expected p, got {}",
code as char
)));
}
let len = match read.read_i32().await {
Ok(len) => len,
Err(_) => {
return Err(Error::ClientSocketError(
"password message length".into(),
client_identifier,
))
}
};
let mut password_response = vec![0u8; (len - 4) as usize];
match read.read_exact(&mut password_response).await {
Ok(_) => (),
Err(_) => {
return Err(Error::ClientSocketError(
"password message".into(),
client_identifier,
))
}
};
// Compare server and client hashes.
let password_hash = md5_hash_password(
&config.general.admin_username,
&config.general.admin_password,
&salt,
);
if password_hash != password_response {
let error =
Error::ClientGeneralError("Invalid password".into(), client_identifier);
warn!("{}", error);
wrong_password(&mut write, username).await?;
return Err(error);
}
}
}
(false, generate_server_parameters_for_admin())
}
// Authenticate normal user.
else {
let pool = match get_pool(pool_name, username) {
Some(pool) => pool,
None => {
error_response(
&mut write,
&format!(
"No pool configured for database: {:?}, user: {:?}",
pool_name, username
),
)
.await?;
return Err(Error::ClientGeneralError(
"Invalid pool name".into(),
client_identifier,
));
}
};
// Obtain the hash to compare, we give preference to that written in cleartext in config
// if there is nothing set in cleartext and auth passthrough (auth_query) is configured, we use the hash obtained
// when the pool was created. If there is no hash there, we try to fetch it one more time.
match pool.settings.user.auth_type {
AuthType::Trust => (),
AuthType::MD5 => {
// Perform MD5 authentication. // Perform MD5 authentication.
// TODO: Add SASL support. // TODO: Add SASL support.
let salt = md5_challenge(&mut write).await?; let salt = md5_challenge(&mut write).await?;
@@ -525,54 +626,6 @@ where
} }
}; };
let mut prepared_statements_enabled = false;
// Authenticate admin user.
let (transaction_mode, mut server_parameters) = if admin {
let config = get_config();
// Compare server and client hashes.
let password_hash = md5_hash_password(
&config.general.admin_username,
&config.general.admin_password,
&salt,
);
if password_hash != password_response {
let error = Error::ClientGeneralError("Invalid password".into(), client_identifier);
warn!("{}", error);
wrong_password(&mut write, username).await?;
return Err(error);
}
(false, generate_server_parameters_for_admin())
}
// Authenticate normal user.
else {
let pool = match get_pool(pool_name, username) {
Some(pool) => pool,
None => {
error_response(
&mut write,
&format!(
"No pool configured for database: {:?}, user: {:?}",
pool_name, username
),
)
.await?;
return Err(Error::ClientGeneralError(
"Invalid pool name".into(),
client_identifier,
));
}
};
// Obtain the hash to compare, we give preference to that written in cleartext in config
// if there is nothing set in cleartext and auth passthrough (auth_query) is configured, we use the hash obtained
// when the pool was created. If there is no hash there, we try to fetch it one more time.
let password_hash = if let Some(password) = &pool.settings.user.password { let password_hash = if let Some(password) = &pool.settings.user.password {
Some(md5_hash_password(username, password, &salt)) Some(md5_hash_password(username, password, &salt))
} else { } else {
@@ -593,7 +646,10 @@ where
match refetch_auth_hash(&pool).await { match refetch_auth_hash(&pool).await {
Ok(fetched_hash) => { Ok(fetched_hash) => {
warn!("Password for {}, obtained. Updating.", client_identifier); warn!(
"Password for {}, obtained. Updating.",
client_identifier
);
{ {
let mut pool_auth_hash = pool.auth_hash.write(); let mut pool_auth_hash = pool.auth_hash.write();
@@ -658,7 +714,8 @@ where
)); ));
} }
} }
}
}
let transaction_mode = pool.settings.pool_mode == PoolMode::Transaction; let transaction_mode = pool.settings.pool_mode == PoolMode::Transaction;
prepared_statements_enabled = prepared_statements_enabled =
transaction_mode && pool.prepared_statement_cache.is_some(); transaction_mode && pool.prepared_statement_cache.is_some();

View File

@@ -208,6 +208,9 @@ impl Address {
pub struct User { pub struct User {
pub username: String, pub username: String,
pub password: Option<String>, pub password: Option<String>,
#[serde(default = "User::default_auth_type")]
pub auth_type: AuthType,
pub server_username: Option<String>, pub server_username: Option<String>,
pub server_password: Option<String>, pub server_password: Option<String>,
pub pool_size: u32, pub pool_size: u32,
@@ -225,6 +228,7 @@ impl Default for User {
User { User {
username: String::from("postgres"), username: String::from("postgres"),
password: None, password: None,
auth_type: AuthType::MD5,
server_username: None, server_username: None,
server_password: None, server_password: None,
pool_size: 15, pool_size: 15,
@@ -239,6 +243,10 @@ impl Default for User {
} }
impl User { impl User {
pub fn default_auth_type() -> AuthType {
AuthType::MD5
}
fn validate(&self) -> Result<(), Error> { fn validate(&self) -> Result<(), Error> {
if let Some(min_pool_size) = self.min_pool_size { if let Some(min_pool_size) = self.min_pool_size {
if min_pool_size > self.pool_size { if min_pool_size > self.pool_size {
@@ -334,6 +342,9 @@ pub struct General {
pub admin_username: String, pub admin_username: String,
pub admin_password: String, pub admin_password: String,
#[serde(default = "General::default_admin_auth_type")]
pub admin_auth_type: AuthType,
#[serde(default = "General::default_validate_config")] #[serde(default = "General::default_validate_config")]
pub validate_config: bool, pub validate_config: bool,
@@ -348,6 +359,10 @@ impl General {
"0.0.0.0".into() "0.0.0.0".into()
} }
pub fn default_admin_auth_type() -> AuthType {
AuthType::MD5
}
pub fn default_port() -> u16 { pub fn default_port() -> u16 {
5432 5432
} }
@@ -456,6 +471,7 @@ impl Default for General {
verify_server_certificate: false, verify_server_certificate: false,
admin_username: String::from("admin"), admin_username: String::from("admin"),
admin_password: String::from("admin"), admin_password: String::from("admin"),
admin_auth_type: AuthType::MD5,
validate_config: true, validate_config: true,
auth_query: None, auth_query: None,
auth_query_user: None, auth_query_user: None,
@@ -476,6 +492,15 @@ pub enum PoolMode {
Session, Session,
} }
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Copy, Hash)]
pub enum AuthType {
#[serde(alias = "trust", alias = "Trust")]
Trust,
#[serde(alias = "md5", alias = "MD5")]
MD5,
}
impl std::fmt::Display for PoolMode { impl std::fmt::Display for PoolMode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self { match self {

71
tests/python/test_auth.py Normal file
View File

@@ -0,0 +1,71 @@
import utils
import signal
class TestTrustAuth:
@classmethod
def setup_method(cls):
config= """
[general]
host = "0.0.0.0"
port = 6432
admin_username = "admin_user"
admin_password = ""
admin_auth_type = "trust"
[pools.sharded_db.users.0]
username = "sharding_user"
password = "sharding_user"
auth_type = "trust"
pool_size = 10
min_pool_size = 1
pool_mode = "transaction"
[pools.sharded_db.shards.0]
servers = [
[ "127.0.0.1", 5432, "primary" ],
]
database = "shard0"
"""
utils.pgcat_generic_start(config)
@classmethod
def teardown_method(self):
utils.pg_cat_send_signal(signal.SIGTERM)
def test_admin_trust_auth(self):
conn, cur = utils.connect_db_trust(admin=True)
cur.execute("SHOW POOLS")
res = cur.fetchall()
print(res)
utils.cleanup_conn(conn, cur)
def test_normal_trust_auth(self):
conn, cur = utils.connect_db_trust(autocommit=False)
cur.execute("SELECT 1")
res = cur.fetchall()
print(res)
utils.cleanup_conn(conn, cur)
class TestMD5Auth:
@classmethod
def setup_method(cls):
utils.pgcat_start()
@classmethod
def teardown_method(self):
utils.pg_cat_send_signal(signal.SIGTERM)
def test_normal_db_access(self):
conn, cur = utils.connect_db(autocommit=False)
cur.execute("SELECT 1")
res = cur.fetchall()
print(res)
utils.cleanup_conn(conn, cur)
def test_admin_db_access(self):
conn, cur = utils.connect_db(admin=True)
cur.execute("SHOW POOLS")
res = cur.fetchall()
print(res)
utils.cleanup_conn(conn, cur)

View File

@@ -1,30 +1,12 @@
import os
import signal import signal
import time import time
import psycopg2 import psycopg2
import utils import utils
SHUTDOWN_TIMEOUT = 5 SHUTDOWN_TIMEOUT = 5
def test_normal_db_access():
utils.pgcat_start()
conn, cur = utils.connect_db(autocommit=False)
cur.execute("SELECT 1")
res = cur.fetchall()
print(res)
utils.cleanup_conn(conn, cur)
def test_admin_db_access():
conn, cur = utils.connect_db(admin=True)
cur.execute("SHOW POOLS")
res = cur.fetchall()
print(res)
utils.cleanup_conn(conn, cur)
def test_shutdown_logic(): def test_shutdown_logic():
@@ -256,3 +238,5 @@ def test_shutdown_logic():
utils.cleanup_conn(conn, cur) utils.cleanup_conn(conn, cur)
utils.pg_cat_send_signal(signal.SIGTERM) utils.pg_cat_send_signal(signal.SIGTERM)
# - - - - - - - - - - - - - - - - - -

View File

@@ -1,20 +1,49 @@
from typing import Tuple
import os import os
import psutil
import signal import signal
import time import time
from typing import Tuple
import tempfile
import psutil
import psycopg2 import psycopg2
PGCAT_HOST = "127.0.0.1" PGCAT_HOST = "127.0.0.1"
PGCAT_PORT = "6432" PGCAT_PORT = "6432"
def pgcat_start():
def _pgcat_start(config_path: str):
pg_cat_send_signal(signal.SIGTERM) pg_cat_send_signal(signal.SIGTERM)
os.system("./target/debug/pgcat .circleci/pgcat.toml &") os.system(f"./target/debug/pgcat {config_path} &")
time.sleep(2) time.sleep(2)
def pgcat_start():
_pgcat_start(config_path='.circleci/pgcat.toml')
def pgcat_generic_start(config: str):
tmp = tempfile.NamedTemporaryFile()
with open(tmp.name, 'w') as f:
f.write(config)
_pgcat_start(config_path=tmp.name)
def glauth_send_signal(signal: signal.Signals):
try:
for proc in psutil.process_iter(["pid", "name"]):
if proc.name() == "glauth":
os.kill(proc.pid, signal)
except Exception as e:
# The process can be gone when we send this signal
print(e)
if signal == signal.SIGTERM:
# Returns 0 if pgcat process exists
time.sleep(2)
if not os.system('pgrep glauth'):
raise Exception("glauth not closed after SIGTERM")
def pg_cat_send_signal(signal: signal.Signals): def pg_cat_send_signal(signal: signal.Signals):
try: try:
for proc in psutil.process_iter(["pid", "name"]): for proc in psutil.process_iter(["pid", "name"]):
@@ -54,6 +83,27 @@ def connect_db(
return (conn, cur) return (conn, cur)
def connect_db_trust(
autocommit: bool = True,
admin: bool = False,
) -> Tuple[psycopg2.extensions.connection, psycopg2.extensions.cursor]:
if admin:
user = "admin_user"
db = "pgcat"
else:
user = "sharding_user"
db = "sharded_db"
conn = psycopg2.connect(
f"postgres://{user}@{PGCAT_HOST}:{PGCAT_PORT}/{db}?application_name=testing_pgcat",
connect_timeout=2,
)
conn.autocommit = autocommit
cur = conn.cursor()
return (conn, cur)
def cleanup_conn(conn: psycopg2.extensions.connection, cur: psycopg2.extensions.cursor): def cleanup_conn(conn: psycopg2.extensions.connection, cur: psycopg2.extensions.cursor):
cur.close() cur.close()

View File

@@ -24,8 +24,8 @@ GEM
pg (1.3.2) pg (1.3.2)
rainbow (3.1.1) rainbow (3.1.1)
regexp_parser (2.3.1) regexp_parser (2.3.1)
rexml (3.2.8) rexml (3.3.6)
strscan (>= 3.0.9) strscan
rspec (3.11.0) rspec (3.11.0)
rspec-core (~> 3.11.0) rspec-core (~> 3.11.0)
rspec-expectations (~> 3.11.0) rspec-expectations (~> 3.11.0)