mirror of
https://github.com/postgresml/pgcat.git
synced 2026-03-22 17:06:29 +00:00
Support unnamed prepared statements (#635)
* Add golang test suite to reproduce issue with unnamed parameterized prepared statements * Allow caching of unnamed prepared statements * Passthrough describe on portals * Remove unneeded kill * Update Dockerfile.ci with golang * Move out update of Dockerfiles to separate PR
This commit is contained in:
committed by
GitHub
parent
b45c6b1d23
commit
7c37da2fad
@@ -108,6 +108,15 @@ cd ../..
|
||||
pip3 install -r tests/python/requirements.txt
|
||||
python3 tests/python/tests.py || exit 1
|
||||
|
||||
|
||||
#
|
||||
# Go tests
|
||||
# Starts its own pgcat server
|
||||
#
|
||||
pushd tests/go
|
||||
/usr/local/go/bin/go test || exit 1
|
||||
popd
|
||||
|
||||
start_pgcat "info"
|
||||
|
||||
# Admin tests
|
||||
|
||||
@@ -1704,18 +1704,14 @@ where
|
||||
/// and also the pool's statement cache. Add it to extended protocol data.
|
||||
fn buffer_parse(&mut self, message: BytesMut, pool: &ConnectionPool) -> Result<(), Error> {
|
||||
// Avoid parsing if prepared statements not enabled
|
||||
let client_given_name = match self.prepared_statements_enabled {
|
||||
true => Parse::get_name(&message)?,
|
||||
false => "".to_string(),
|
||||
};
|
||||
|
||||
if client_given_name.is_empty() {
|
||||
if !self.prepared_statements_enabled {
|
||||
debug!("Anonymous parse message");
|
||||
self.extended_protocol_data_buffer
|
||||
.push_back(ExtendedProtocolData::create_new_parse(message, None));
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let client_given_name = Parse::get_name(&message)?;
|
||||
let parse: Parse = (&message).try_into()?;
|
||||
|
||||
// Compute the hash of the parse statement
|
||||
@@ -1753,18 +1749,15 @@ where
|
||||
/// saved in the client cache.
|
||||
async fn buffer_bind(&mut self, message: BytesMut) -> Result<(), Error> {
|
||||
// Avoid parsing if prepared statements not enabled
|
||||
let client_given_name = match self.prepared_statements_enabled {
|
||||
true => Bind::get_name(&message)?,
|
||||
false => "".to_string(),
|
||||
};
|
||||
|
||||
if client_given_name.is_empty() {
|
||||
if !self.prepared_statements_enabled {
|
||||
debug!("Anonymous bind message");
|
||||
self.extended_protocol_data_buffer
|
||||
.push_back(ExtendedProtocolData::create_new_bind(message, None));
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let client_given_name = Bind::get_name(&message)?;
|
||||
|
||||
match self.prepared_statements.get(&client_given_name) {
|
||||
Some((rewritten_parse, _)) => {
|
||||
let message = Bind::rename(message, &rewritten_parse.name)?;
|
||||
@@ -1807,12 +1800,7 @@ where
|
||||
/// saved in the client cache.
|
||||
async fn buffer_describe(&mut self, message: BytesMut) -> Result<(), Error> {
|
||||
// Avoid parsing if prepared statements not enabled
|
||||
let describe: Describe = match self.prepared_statements_enabled {
|
||||
true => (&message).try_into()?,
|
||||
false => Describe::empty_new(),
|
||||
};
|
||||
|
||||
if describe.anonymous() {
|
||||
if !self.prepared_statements_enabled {
|
||||
debug!("Anonymous describe message");
|
||||
self.extended_protocol_data_buffer
|
||||
.push_back(ExtendedProtocolData::create_new_describe(message, None));
|
||||
@@ -1820,6 +1808,15 @@ where
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let describe: Describe = (&message).try_into()?;
|
||||
if describe.target == 'P' {
|
||||
debug!("Portal describe message");
|
||||
self.extended_protocol_data_buffer
|
||||
.push_back(ExtendedProtocolData::create_new_describe(message, None));
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let client_given_name = describe.statement_name.clone();
|
||||
|
||||
match self.prepared_statements.get(&client_given_name) {
|
||||
|
||||
@@ -1109,7 +1109,7 @@ pub struct Describe {
|
||||
|
||||
#[allow(dead_code)]
|
||||
len: i32,
|
||||
target: char,
|
||||
pub target: char,
|
||||
pub statement_name: String,
|
||||
}
|
||||
|
||||
|
||||
5
tests/go/go.mod
Normal file
5
tests/go/go.mod
Normal file
@@ -0,0 +1,5 @@
|
||||
module pgcat
|
||||
|
||||
go 1.21
|
||||
|
||||
require github.com/lib/pq v1.10.9
|
||||
2
tests/go/go.sum
Normal file
2
tests/go/go.sum
Normal file
@@ -0,0 +1,2 @@
|
||||
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||
162
tests/go/pgcat.toml
Normal file
162
tests/go/pgcat.toml
Normal file
@@ -0,0 +1,162 @@
|
||||
#
|
||||
# PgCat config example.
|
||||
#
|
||||
|
||||
#
|
||||
# General pooler settings
|
||||
[general]
|
||||
# What IP to run on, 0.0.0.0 means accessible from everywhere.
|
||||
host = "0.0.0.0"
|
||||
|
||||
# Port to run on, same as PgBouncer used in this example.
|
||||
port = "${PORT}"
|
||||
|
||||
# Whether to enable prometheus exporter or not.
|
||||
enable_prometheus_exporter = true
|
||||
|
||||
# Port at which prometheus exporter listens on.
|
||||
prometheus_exporter_port = 9930
|
||||
|
||||
# How long to wait before aborting a server connection (ms).
|
||||
connect_timeout = 1000
|
||||
|
||||
# How much time to give the health check query to return with a result (ms).
|
||||
healthcheck_timeout = 1000
|
||||
|
||||
# How long to keep connection available for immediate re-use, without running a healthcheck query on it
|
||||
healthcheck_delay = 30000
|
||||
|
||||
# How much time to give clients during shutdown before forcibly killing client connections (ms).
|
||||
shutdown_timeout = 5000
|
||||
|
||||
# For how long to ban a server if it fails a health check (seconds).
|
||||
ban_time = 60 # Seconds
|
||||
|
||||
# If we should log client connections
|
||||
log_client_connections = false
|
||||
|
||||
# If we should log client disconnections
|
||||
log_client_disconnections = false
|
||||
|
||||
# Reload config automatically if it changes.
|
||||
autoreload = 15000
|
||||
|
||||
server_round_robin = false
|
||||
|
||||
# TLS
|
||||
tls_certificate = "../../.circleci/server.cert"
|
||||
tls_private_key = "../../.circleci/server.key"
|
||||
|
||||
# Credentials to access the virtual administrative database (pgbouncer or pgcat)
|
||||
# Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DATABASES`, etc..
|
||||
admin_username = "admin_user"
|
||||
admin_password = "admin_pass"
|
||||
|
||||
# pool
|
||||
# configs are structured as pool.<pool_name>
|
||||
# the pool_name is what clients use as database name when connecting
|
||||
# For the example below a client can connect using "postgres://sharding_user:sharding_user@pgcat_host:pgcat_port/sharded_db"
|
||||
[pools.sharded_db]
|
||||
# Pool mode (see PgBouncer docs for more).
|
||||
# session: one server connection per connected client
|
||||
# transaction: one server connection per client transaction
|
||||
pool_mode = "transaction"
|
||||
|
||||
# If the client doesn't specify, route traffic to
|
||||
# this role by default.
|
||||
#
|
||||
# any: round-robin between primary and replicas,
|
||||
# replica: round-robin between replicas only without touching the primary,
|
||||
# primary: all queries go to the primary unless otherwise specified.
|
||||
default_role = "any"
|
||||
|
||||
# Query parser. If enabled, we'll attempt to parse
|
||||
# every incoming query to determine if it's a read or a write.
|
||||
# If it's a read query, we'll direct it to a replica. Otherwise, if it's a write,
|
||||
# we'll direct it to the primary.
|
||||
query_parser_enabled = true
|
||||
|
||||
# If the query parser is enabled and this setting is enabled, we'll attempt to
|
||||
# infer the role from the query itself.
|
||||
query_parser_read_write_splitting = true
|
||||
|
||||
# If the query parser is enabled and this setting is enabled, the primary will be part of the pool of databases used for
|
||||
# load balancing of read queries. Otherwise, the primary will only be used for write
|
||||
# queries. The primary can always be explicitely selected with our custom protocol.
|
||||
primary_reads_enabled = true
|
||||
|
||||
# So what if you wanted to implement a different hashing function,
|
||||
# or you've already built one and you want this pooler to use it?
|
||||
#
|
||||
# Current options:
|
||||
#
|
||||
# pg_bigint_hash: PARTITION BY HASH (Postgres hashing function)
|
||||
# sha1: A hashing function based on SHA1
|
||||
#
|
||||
sharding_function = "pg_bigint_hash"
|
||||
|
||||
# Prepared statements cache size.
|
||||
prepared_statements_cache_size = 500
|
||||
|
||||
# Credentials for users that may connect to this cluster
|
||||
[pools.sharded_db.users.0]
|
||||
username = "sharding_user"
|
||||
password = "sharding_user"
|
||||
# Maximum number of server connections that can be established for this user
|
||||
# The maximum number of connection from a single Pgcat process to any database in the cluster
|
||||
# is the sum of pool_size across all users.
|
||||
pool_size = 5
|
||||
statement_timeout = 0
|
||||
|
||||
|
||||
[pools.sharded_db.users.1]
|
||||
username = "other_user"
|
||||
password = "other_user"
|
||||
pool_size = 21
|
||||
statement_timeout = 30000
|
||||
|
||||
# Shard 0
|
||||
[pools.sharded_db.shards.0]
|
||||
# [ host, port, role ]
|
||||
servers = [
|
||||
[ "127.0.0.1", 5432, "primary" ],
|
||||
[ "localhost", 5432, "replica" ]
|
||||
]
|
||||
# Database name (e.g. "postgres")
|
||||
database = "shard0"
|
||||
|
||||
[pools.sharded_db.shards.1]
|
||||
servers = [
|
||||
[ "127.0.0.1", 5432, "primary" ],
|
||||
[ "localhost", 5432, "replica" ],
|
||||
]
|
||||
database = "shard1"
|
||||
|
||||
[pools.sharded_db.shards.2]
|
||||
servers = [
|
||||
[ "127.0.0.1", 5432, "primary" ],
|
||||
[ "localhost", 5432, "replica" ],
|
||||
]
|
||||
database = "shard2"
|
||||
|
||||
|
||||
[pools.simple_db]
|
||||
pool_mode = "session"
|
||||
default_role = "primary"
|
||||
query_parser_enabled = true
|
||||
query_parser_read_write_splitting = true
|
||||
primary_reads_enabled = true
|
||||
sharding_function = "pg_bigint_hash"
|
||||
|
||||
[pools.simple_db.users.0]
|
||||
username = "simple_user"
|
||||
password = "simple_user"
|
||||
pool_size = 5
|
||||
statement_timeout = 30000
|
||||
|
||||
[pools.simple_db.shards.0]
|
||||
servers = [
|
||||
[ "127.0.0.1", 5432, "primary" ],
|
||||
[ "localhost", 5432, "replica" ]
|
||||
]
|
||||
database = "some_db"
|
||||
52
tests/go/prepared_test.go
Normal file
52
tests/go/prepared_test.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package pgcat
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
_ "github.com/lib/pq"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test(t *testing.T) {
|
||||
t.Cleanup(setup(t))
|
||||
t.Run("Named parameterized prepared statement works", namedParameterizedPreparedStatement)
|
||||
t.Run("Unnamed parameterized prepared statement works", unnamedParameterizedPreparedStatement)
|
||||
}
|
||||
|
||||
func namedParameterizedPreparedStatement(t *testing.T) {
|
||||
db, err := sql.Open("postgres", fmt.Sprintf("host=localhost port=%d database=sharded_db user=sharding_user password=sharding_user sslmode=disable", port))
|
||||
if err != nil {
|
||||
t.Fatalf("could not open connection: %+v", err)
|
||||
}
|
||||
|
||||
stmt, err := db.Prepare("SELECT $1")
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("could not prepare: %+v", err)
|
||||
}
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
rows, err := stmt.Query(1)
|
||||
if err != nil {
|
||||
t.Fatalf("could not query: %+v", err)
|
||||
}
|
||||
_ = rows.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func unnamedParameterizedPreparedStatement(t *testing.T) {
|
||||
db, err := sql.Open("postgres", fmt.Sprintf("host=localhost port=%d database=sharded_db user=sharding_user password=sharding_user sslmode=disable", port))
|
||||
if err != nil {
|
||||
t.Fatalf("could not open connection: %+v", err)
|
||||
}
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
// Under the hood QueryContext generates an unnamed parameterized prepared statement
|
||||
rows, err := db.QueryContext(context.Background(), "SELECT $1", 1)
|
||||
if err != nil {
|
||||
t.Fatalf("could not query: %+v", err)
|
||||
}
|
||||
_ = rows.Close()
|
||||
}
|
||||
}
|
||||
81
tests/go/setup.go
Normal file
81
tests/go/setup.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package pgcat
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
//go:embed pgcat.toml
|
||||
var pgcatCfg string
|
||||
|
||||
var port = rand.Intn(32760-20000) + 20000
|
||||
|
||||
func setup(t *testing.T) func() {
|
||||
cfg, err := os.CreateTemp("/tmp", "pgcat_cfg_*.toml")
|
||||
if err != nil {
|
||||
t.Fatalf("could not create temp file: %+v", err)
|
||||
}
|
||||
|
||||
pgcatCfg = strings.Replace(pgcatCfg, "\"${PORT}\"", fmt.Sprintf("%d", port), 1)
|
||||
|
||||
_, err = cfg.Write([]byte(pgcatCfg))
|
||||
if err != nil {
|
||||
t.Fatalf("could not write temp file: %+v", err)
|
||||
}
|
||||
|
||||
commandPath := "../../target/debug/pgcat"
|
||||
if os.Getenv("CARGO_TARGET_DIR") != "" {
|
||||
commandPath = os.Getenv("CARGO_TARGET_DIR") + "/debug/pgcat"
|
||||
}
|
||||
|
||||
cmd := exec.Command(commandPath, cfg.Name())
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
go func() {
|
||||
err = cmd.Run()
|
||||
if err != nil {
|
||||
t.Errorf("could not run pgcat: %+v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
deadline, cancelFunc := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second))
|
||||
defer cancelFunc()
|
||||
for {
|
||||
select {
|
||||
case <-deadline.Done():
|
||||
break
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
db, err := sql.Open("postgres", fmt.Sprintf("host=localhost port=%d database=pgcat user=admin_user password=admin_pass sslmode=disable", port))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
rows, err := db.QueryContext(deadline, "SHOW STATS")
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
_ = rows.Close()
|
||||
_ = db.Close()
|
||||
break
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
return func() {
|
||||
err := cmd.Process.Signal(os.Interrupt)
|
||||
if err != nil {
|
||||
t.Fatalf("could not interrupt pgcat: %+v", err)
|
||||
}
|
||||
err = os.Remove(cfg.Name())
|
||||
if err != nil {
|
||||
t.Fatalf("could not remove temp file: %+v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user