Compare commits

...

2 Commits

Author SHA1 Message Date
Lev Kokotov
5872354c3e remove debug 2022-08-28 17:29:13 -07:00
Lev Kokotov
48bb6ebeef Support settings custom search path 2022-08-28 17:23:28 -07:00
6 changed files with 35 additions and 5 deletions

View File

@@ -108,6 +108,7 @@ servers = [
] ]
# Database name (e.g. "postgres") # Database name (e.g. "postgres")
database = "shard0" database = "shard0"
search_path = "\"$user\",public"
[pools.sharded_db.shards.1] [pools.sharded_db.shards.1]
servers = [ servers = [

View File

@@ -644,8 +644,8 @@ where
// SET SHARD TO // SET SHARD TO
Some((Command::SetShard, _)) => { Some((Command::SetShard, _)) => {
// Selected shard is not configured. let shard = query_router.shard();
if query_router.shard() >= pool.shards() { if shard >= pool.shards() {
// Set the shard back to what it was. // Set the shard back to what it was.
query_router.set_shard(current_shard); query_router.set_shard(current_shard);
@@ -653,7 +653,7 @@ where
&mut self.write, &mut self.write,
&format!( &format!(
"shard {} is more than configured {}, staying on shard {}", "shard {} is more than configured {}, staying on shard {}",
query_router.shard(), shard,
pool.shards(), pool.shards(),
current_shard, current_shard,
), ),

View File

@@ -72,6 +72,9 @@ pub struct Address {
/// The name of the Postgres database. /// The name of the Postgres database.
pub database: String, pub database: String,
/// Default search_path.
pub search_path: Option<String>,
/// Server role: replica, primary. /// Server role: replica, primary.
pub role: Role, pub role: Role,
@@ -98,6 +101,7 @@ impl Default for Address {
address_index: 0, address_index: 0,
replica_number: 0, replica_number: 0,
database: String::from("database"), database: String::from("database"),
search_path: None,
role: Role::Replica, role: Role::Replica,
username: String::from("username"), username: String::from("username"),
pool_name: String::from("pool_name"), pool_name: String::from("pool_name"),
@@ -206,6 +210,7 @@ impl Default for Pool {
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Shard { pub struct Shard {
pub database: String, pub database: String,
pub search_path: Option<String>,
pub servers: Vec<(String, u16, String)>, pub servers: Vec<(String, u16, String)>,
} }
@@ -213,6 +218,7 @@ impl Default for Shard {
fn default() -> Shard { fn default() -> Shard {
Shard { Shard {
servers: vec![(String::from("localhost"), 5432, String::from("primary"))], servers: vec![(String::from("localhost"), 5432, String::from("primary"))],
search_path: None,
database: String::from("postgres"), database: String::from("postgres"),
} }
} }

View File

@@ -111,7 +111,12 @@ where
/// Send the startup packet the server. We're pretending we're a Pg client. /// Send the startup packet the server. We're pretending we're a Pg client.
/// This tells the server which user we are and what database we want. /// This tells the server which user we are and what database we want.
pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Result<(), Error> { pub async fn startup(
stream: &mut TcpStream,
user: &str,
database: &str,
search_path: Option<&String>,
) -> Result<(), Error> {
let mut bytes = BytesMut::with_capacity(25); let mut bytes = BytesMut::with_capacity(25);
bytes.put_i32(196608); // Protocol number bytes.put_i32(196608); // Protocol number
@@ -125,6 +130,17 @@ pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Resu
bytes.put(&b"database\0"[..]); bytes.put(&b"database\0"[..]);
bytes.put_slice(&database.as_bytes()); bytes.put_slice(&database.as_bytes());
bytes.put_u8(0); bytes.put_u8(0);
// search_path
match search_path {
Some(search_path) => {
bytes.put(&b"options\0"[..]);
bytes.put_slice(&format!("-c search_path={}", search_path).as_bytes());
bytes.put_u8(0);
}
None => (),
};
bytes.put_u8(0); // Null terminator bytes.put_u8(0); // Null terminator
let len = bytes.len() as i32 + 4i32; let len = bytes.len() as i32 + 4i32;

View File

@@ -155,6 +155,7 @@ impl ConnectionPool {
let address = Address { let address = Address {
id: address_id, id: address_id,
database: shard.database.clone(), database: shard.database.clone(),
search_path: shard.search_path.clone(),
host: server.0.clone(), host: server.0.clone(),
port: server.1 as u16, port: server.1 as u16,
role: role, role: role,

View File

@@ -86,7 +86,13 @@ impl Server {
trace!("Sending StartupMessage"); trace!("Sending StartupMessage");
// StartupMessage // StartupMessage
startup(&mut stream, &user.username, database).await?; startup(
&mut stream,
&user.username,
database,
address.search_path.as_ref(),
)
.await?;
let mut server_info = BytesMut::new(); let mut server_info = BytesMut::new();
let mut process_id: i32 = 0; let mut process_id: i32 = 0;