Support settings custom search path

This commit is contained in:
Lev Kokotov
2022-08-28 17:23:28 -07:00
parent 3bc4f9351c
commit 48bb6ebeef
6 changed files with 37 additions and 5 deletions

View File

@@ -108,6 +108,7 @@ servers = [
]
# Database name (e.g. "postgres")
database = "shard0"
search_path = "\"$user\",public"
[pools.sharded_db.shards.1]
servers = [

View File

@@ -354,6 +354,8 @@ where
let stats = get_reporter();
let parameters = parse_startup(bytes.clone())?;
info!("params: {:?}", parameters);
// These two parameters are mandatory by the protocol.
let pool_name = match parameters.get("database") {
Some(db) => db,
@@ -644,8 +646,8 @@ where
// SET SHARD TO
Some((Command::SetShard, _)) => {
// Selected shard is not configured.
if query_router.shard() >= pool.shards() {
let shard = query_router.shard();
if shard >= pool.shards() {
// Set the shard back to what it was.
query_router.set_shard(current_shard);
@@ -653,7 +655,7 @@ where
&mut self.write,
&format!(
"shard {} is more than configured {}, staying on shard {}",
query_router.shard(),
shard,
pool.shards(),
current_shard,
),

View File

@@ -72,6 +72,9 @@ pub struct Address {
/// The name of the Postgres database.
pub database: String,
/// Default search_path.
pub search_path: Option<String>,
/// Server role: replica, primary.
pub role: Role,
@@ -98,6 +101,7 @@ impl Default for Address {
address_index: 0,
replica_number: 0,
database: String::from("database"),
search_path: None,
role: Role::Replica,
username: String::from("username"),
pool_name: String::from("pool_name"),
@@ -206,6 +210,7 @@ impl Default for Pool {
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct Shard {
pub database: String,
pub search_path: Option<String>,
pub servers: Vec<(String, u16, String)>,
}
@@ -213,6 +218,7 @@ impl Default for Shard {
fn default() -> Shard {
Shard {
servers: vec![(String::from("localhost"), 5432, String::from("primary"))],
search_path: None,
database: String::from("postgres"),
}
}

View File

@@ -111,7 +111,12 @@ where
/// Send the startup packet the server. We're pretending we're a Pg client.
/// This tells the server which user we are and what database we want.
pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Result<(), Error> {
pub async fn startup(
stream: &mut TcpStream,
user: &str,
database: &str,
search_path: Option<&String>,
) -> Result<(), Error> {
let mut bytes = BytesMut::with_capacity(25);
bytes.put_i32(196608); // Protocol number
@@ -125,6 +130,17 @@ pub async fn startup(stream: &mut TcpStream, user: &str, database: &str) -> Resu
bytes.put(&b"database\0"[..]);
bytes.put_slice(&database.as_bytes());
bytes.put_u8(0);
// search_path
match search_path {
Some(search_path) => {
bytes.put(&b"options\0"[..]);
bytes.put_slice(&format!("-c search_path={}", search_path).as_bytes());
bytes.put_u8(0);
}
None => (),
};
bytes.put_u8(0); // Null terminator
let len = bytes.len() as i32 + 4i32;

View File

@@ -155,6 +155,7 @@ impl ConnectionPool {
let address = Address {
id: address_id,
database: shard.database.clone(),
search_path: shard.search_path.clone(),
host: server.0.clone(),
port: server.1 as u16,
role: role,

View File

@@ -86,7 +86,13 @@ impl Server {
trace!("Sending StartupMessage");
// StartupMessage
startup(&mut stream, &user.username, database).await?;
startup(
&mut stream,
&user.username,
database,
address.search_path.as_ref(),
)
.await?;
let mut server_info = BytesMut::new();
let mut process_id: i32 = 0;