diff --git a/.circleci/config.yml b/.circleci/config.yml index 09fc893..e9bb12f 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -32,9 +32,9 @@ jobs: command: "cargo fmt --check" - run: name: "Install dependencies" - command: "sudo apt-get update && sudo apt-get install -y psmisc postgresql-contrib-12 postgresql-client-12 ruby ruby-dev libpq-dev" + command: "sudo apt-get update && sudo apt-get install -y psmisc postgresql-contrib-12 postgresql-client-12 ruby ruby-dev libpq-dev python" - run: - name: "Build" + name: "Build" command: "cargo build" - run: name: "Test" diff --git a/.circleci/run_tests.sh b/.circleci/run_tests.sh index c932a86..a6fb790 100644 --- a/.circleci/run_tests.sh +++ b/.circleci/run_tests.sh @@ -69,8 +69,17 @@ psql -U sharding_user -e -h 127.0.0.1 -p 6432 -f tests/sharding/query_routing_te cd tests/ruby && \ sudo gem install bundler && \ bundle install && \ - ruby tests.rb && \ -cd ../.. + ruby tests.rb +cd /home/circleci/project + +# +# Python tests +# +cd tests/python && \ + pip install -r requirements.txt && \ + python tests.py +cd /home/circleci/project + # Admin tests export PGPASSWORD=admin_pass diff --git a/src/admin.rs b/src/admin.rs index 163227d..831ca0b 100644 --- a/src/admin.rs +++ b/src/admin.rs @@ -10,6 +10,18 @@ use crate::pool::get_all_pools; use crate::stats::get_stats; use crate::ClientServerMap; +pub fn generate_server_info_for_admin() -> BytesMut { + let mut server_info = BytesMut::new(); + + server_info.put(server_paramater_message("application_name", "")); + server_info.put(server_paramater_message("client_encoding", "UTF8")); + server_info.put(server_paramater_message("server_encoding", "UTF8")); + server_info.put(server_paramater_message("server_version", VERSION)); + server_info.put(server_paramater_message("DateStyle", "ISO, MDY")); + + return server_info; +} + /// Handle admin client. pub async fn handle_admin( stream: &mut T, diff --git a/src/client.rs b/src/client.rs index 4f32d0a..1775ad2 100644 --- a/src/client.rs +++ b/src/client.rs @@ -5,7 +5,7 @@ use std::collections::HashMap; use tokio::io::{split, AsyncReadExt, BufReader, ReadHalf, WriteHalf}; use tokio::net::TcpStream; -use crate::admin::handle_admin; +use crate::admin::{generate_server_info_for_admin, handle_admin}; use crate::config::get_config; use crate::constants::*; use crate::errors::Error; @@ -311,10 +311,7 @@ where Err(_) => return Err(Error::SocketError), }; - let mut target_pool: ConnectionPool = ConnectionPool::default(); - let mut transaction_mode = false; - - if admin { + let (target_pool, transaction_mode, server_info) = if admin { let correct_user = config.general.admin_username.as_str(); let correct_password = config.general.admin_password.as_str(); @@ -325,8 +322,13 @@ where wrong_password(&mut write, user).await?; return Err(Error::ClientError); } + ( + ConnectionPool::default(), + false, + generate_server_info_for_admin(), + ) } else { - target_pool = match get_pool(database.clone(), user.clone()) { + let target_pool = match get_pool(database.clone(), user.clone()) { Some(pool) => pool, None => { error_response( @@ -340,8 +342,8 @@ where return Err(Error::ClientError); } }; - transaction_mode = target_pool.settings.pool_mode == "transaction"; - + let transaction_mode = target_pool.settings.pool_mode == "transaction"; + let server_info = target_pool.server_info(); // Compare server and client hashes. let correct_password = target_pool.settings.user.password.as_str(); let password_hash = md5_hash_password(user, correct_password, &salt); @@ -351,12 +353,13 @@ where wrong_password(&mut write, user).await?; return Err(Error::ClientError); } - } + (target_pool, transaction_mode, server_info) + }; debug!("Password authentication successful"); auth_ok(&mut write).await?; - write_all(&mut write, target_pool.server_info()).await?; + write_all(&mut write, server_info).await?; backend_key_data(&mut write, process_id, secret_key).await?; ready_for_query(&mut write).await?; diff --git a/src/messages.rs b/src/messages.rs index 89795c6..ba22a57 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -7,6 +7,7 @@ use tokio::net::TcpStream; use crate::errors::Error; use std::collections::HashMap; +use std::mem; /// Postgres data type mappings /// used in RowDescription ('T') message. @@ -498,3 +499,20 @@ where Ok(bytes) } + +pub fn server_paramater_message(key: &str, value: &str) -> BytesMut { + let mut server_info = BytesMut::new(); + + let null_byte_size = 1; + let len: usize = + mem::size_of::() + key.len() + null_byte_size + value.len() + null_byte_size; + + server_info.put_slice("S".as_bytes()); + server_info.put_i32(len.try_into().unwrap()); + server_info.put_slice(key.as_bytes()); + server_info.put_bytes(0, 1); + server_info.put_slice(value.as_bytes()); + server_info.put_bytes(0, 1); + + return server_info; +} diff --git a/tests/python/tests.py b/tests/python/tests.py index 8eb47f6..06d27dc 100644 --- a/tests/python/tests.py +++ b/tests/python/tests.py @@ -1,11 +1,22 @@ import psycopg2 -conn = psycopg2.connect("postgres://random:password@127.0.0.1:6432/db") -cur = conn.cursor() +def test_normal_db_access(): + conn = psycopg2.connect("postgres://sharding_user:sharding_user@127.0.0.1:6432/sharded_db?application_name=testing_pgcat") + cur = conn.cursor() -cur.execute("SELECT 1"); -res = cur.fetchall() + cur.execute("SELECT 1") + res = cur.fetchall() + print(res) -print(res) -# conn.commit() \ No newline at end of file +def test_admin_db_access(): + conn = psycopg2.connect("postgres://user:pass@127.0.0.1:6432/pgcat") + conn.autocommit = True # BEGIN/COMMIT is not supported by admin db + cur = conn.cursor() + + cur.execute("SHOW POOLS") + res = cur.fetchall() + print(res) + +test_normal_db_access() +test_admin_db_access() diff --git a/tests/ruby/tests.rb b/tests/ruby/tests.rb index aaabe5e..c5a55a7 100644 --- a/tests/ruby/tests.rb +++ b/tests/ruby/tests.rb @@ -128,3 +128,16 @@ end 25.times do poorly_behaved_client end + + +def test_server_parameters + server_conn = PG::connect("postgres://sharding_user:sharding_user@127.0.0.1:6432/sharded_db?application_name=testing_pgcat") + raise StandardError, "Bad server version" if server_conn.server_version == 0 + server_conn.close + + admin_conn = PG::connect("postgres://admin_user:admin_pass@127.0.0.1:6432/pgcat") + raise StandardError, "Bad server version" if admin_conn.server_version == 0 + admin_conn.close + + puts 'Server parameters ok' +end