Support unnamed prepared statements (#635)

* Add golang test suite to reproduce issue with unnamed parameterized prepared statements

* Allow caching of unnamed prepared statements

* Passthrough describe on portals

* Remove unneeded kill

* Update Dockerfile.ci with golang

* Move out update of Dockerfiles to separate PR
This commit is contained in:
Jakob Schultz-Falk
2023-11-09 01:36:45 +01:00
committed by GitHub
parent b45c6b1d23
commit 7c37da2fad
8 changed files with 327 additions and 19 deletions

View File

@@ -108,6 +108,15 @@ cd ../..
pip3 install -r tests/python/requirements.txt
python3 tests/python/tests.py || exit 1
#
# Go tests
# Starts its own pgcat server
#
pushd tests/go
/usr/local/go/bin/go test || exit 1
popd
start_pgcat "info"
# Admin tests

View File

@@ -1704,18 +1704,14 @@ where
/// and also the pool's statement cache. Add it to extended protocol data.
fn buffer_parse(&mut self, message: BytesMut, pool: &ConnectionPool) -> Result<(), Error> {
// Avoid parsing if prepared statements not enabled
let client_given_name = match self.prepared_statements_enabled {
true => Parse::get_name(&message)?,
false => "".to_string(),
};
if client_given_name.is_empty() {
if !self.prepared_statements_enabled {
debug!("Anonymous parse message");
self.extended_protocol_data_buffer
.push_back(ExtendedProtocolData::create_new_parse(message, None));
return Ok(());
}
let client_given_name = Parse::get_name(&message)?;
let parse: Parse = (&message).try_into()?;
// Compute the hash of the parse statement
@@ -1753,18 +1749,15 @@ where
/// saved in the client cache.
async fn buffer_bind(&mut self, message: BytesMut) -> Result<(), Error> {
// Avoid parsing if prepared statements not enabled
let client_given_name = match self.prepared_statements_enabled {
true => Bind::get_name(&message)?,
false => "".to_string(),
};
if client_given_name.is_empty() {
if !self.prepared_statements_enabled {
debug!("Anonymous bind message");
self.extended_protocol_data_buffer
.push_back(ExtendedProtocolData::create_new_bind(message, None));
return Ok(());
}
let client_given_name = Bind::get_name(&message)?;
match self.prepared_statements.get(&client_given_name) {
Some((rewritten_parse, _)) => {
let message = Bind::rename(message, &rewritten_parse.name)?;
@@ -1807,12 +1800,7 @@ where
/// saved in the client cache.
async fn buffer_describe(&mut self, message: BytesMut) -> Result<(), Error> {
// Avoid parsing if prepared statements not enabled
let describe: Describe = match self.prepared_statements_enabled {
true => (&message).try_into()?,
false => Describe::empty_new(),
};
if describe.anonymous() {
if !self.prepared_statements_enabled {
debug!("Anonymous describe message");
self.extended_protocol_data_buffer
.push_back(ExtendedProtocolData::create_new_describe(message, None));
@@ -1820,6 +1808,15 @@ where
return Ok(());
}
let describe: Describe = (&message).try_into()?;
if describe.target == 'P' {
debug!("Portal describe message");
self.extended_protocol_data_buffer
.push_back(ExtendedProtocolData::create_new_describe(message, None));
return Ok(());
}
let client_given_name = describe.statement_name.clone();
match self.prepared_statements.get(&client_given_name) {

View File

@@ -1109,7 +1109,7 @@ pub struct Describe {
#[allow(dead_code)]
len: i32,
target: char,
pub target: char,
pub statement_name: String,
}

5
tests/go/go.mod Normal file
View File

@@ -0,0 +1,5 @@
module pgcat
go 1.21
require github.com/lib/pq v1.10.9

2
tests/go/go.sum Normal file
View File

@@ -0,0 +1,2 @@
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=

162
tests/go/pgcat.toml Normal file
View File

@@ -0,0 +1,162 @@
#
# PgCat config example.
#
#
# General pooler settings
[general]
# What IP to run on, 0.0.0.0 means accessible from everywhere.
host = "0.0.0.0"
# Port to run on, same as PgBouncer used in this example.
port = "${PORT}"
# Whether to enable prometheus exporter or not.
enable_prometheus_exporter = true
# Port at which prometheus exporter listens on.
prometheus_exporter_port = 9930
# How long to wait before aborting a server connection (ms).
connect_timeout = 1000
# How much time to give the health check query to return with a result (ms).
healthcheck_timeout = 1000
# How long to keep connection available for immediate re-use, without running a healthcheck query on it
healthcheck_delay = 30000
# How much time to give clients during shutdown before forcibly killing client connections (ms).
shutdown_timeout = 5000
# For how long to ban a server if it fails a health check (seconds).
ban_time = 60 # Seconds
# If we should log client connections
log_client_connections = false
# If we should log client disconnections
log_client_disconnections = false
# Reload config automatically if it changes.
autoreload = 15000
server_round_robin = false
# TLS
tls_certificate = "../../.circleci/server.cert"
tls_private_key = "../../.circleci/server.key"
# Credentials to access the virtual administrative database (pgbouncer or pgcat)
# Connecting to that database allows running commands like `SHOW POOLS`, `SHOW DATABASES`, etc..
admin_username = "admin_user"
admin_password = "admin_pass"
# pool
# configs are structured as pool.<pool_name>
# the pool_name is what clients use as database name when connecting
# For the example below a client can connect using "postgres://sharding_user:sharding_user@pgcat_host:pgcat_port/sharded_db"
[pools.sharded_db]
# Pool mode (see PgBouncer docs for more).
# session: one server connection per connected client
# transaction: one server connection per client transaction
pool_mode = "transaction"
# If the client doesn't specify, route traffic to
# this role by default.
#
# any: round-robin between primary and replicas,
# replica: round-robin between replicas only without touching the primary,
# primary: all queries go to the primary unless otherwise specified.
default_role = "any"
# Query parser. If enabled, we'll attempt to parse
# every incoming query to determine if it's a read or a write.
# If it's a read query, we'll direct it to a replica. Otherwise, if it's a write,
# we'll direct it to the primary.
query_parser_enabled = true
# If the query parser is enabled and this setting is enabled, we'll attempt to
# infer the role from the query itself.
query_parser_read_write_splitting = true
# If the query parser is enabled and this setting is enabled, the primary will be part of the pool of databases used for
# load balancing of read queries. Otherwise, the primary will only be used for write
# queries. The primary can always be explicitely selected with our custom protocol.
primary_reads_enabled = true
# So what if you wanted to implement a different hashing function,
# or you've already built one and you want this pooler to use it?
#
# Current options:
#
# pg_bigint_hash: PARTITION BY HASH (Postgres hashing function)
# sha1: A hashing function based on SHA1
#
sharding_function = "pg_bigint_hash"
# Prepared statements cache size.
prepared_statements_cache_size = 500
# Credentials for users that may connect to this cluster
[pools.sharded_db.users.0]
username = "sharding_user"
password = "sharding_user"
# Maximum number of server connections that can be established for this user
# The maximum number of connection from a single Pgcat process to any database in the cluster
# is the sum of pool_size across all users.
pool_size = 5
statement_timeout = 0
[pools.sharded_db.users.1]
username = "other_user"
password = "other_user"
pool_size = 21
statement_timeout = 30000
# Shard 0
[pools.sharded_db.shards.0]
# [ host, port, role ]
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ]
]
# Database name (e.g. "postgres")
database = "shard0"
[pools.sharded_db.shards.1]
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ],
]
database = "shard1"
[pools.sharded_db.shards.2]
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ],
]
database = "shard2"
[pools.simple_db]
pool_mode = "session"
default_role = "primary"
query_parser_enabled = true
query_parser_read_write_splitting = true
primary_reads_enabled = true
sharding_function = "pg_bigint_hash"
[pools.simple_db.users.0]
username = "simple_user"
password = "simple_user"
pool_size = 5
statement_timeout = 30000
[pools.simple_db.shards.0]
servers = [
[ "127.0.0.1", 5432, "primary" ],
[ "localhost", 5432, "replica" ]
]
database = "some_db"

52
tests/go/prepared_test.go Normal file
View File

@@ -0,0 +1,52 @@
package pgcat
import (
"context"
"database/sql"
"fmt"
_ "github.com/lib/pq"
"testing"
)
func Test(t *testing.T) {
t.Cleanup(setup(t))
t.Run("Named parameterized prepared statement works", namedParameterizedPreparedStatement)
t.Run("Unnamed parameterized prepared statement works", unnamedParameterizedPreparedStatement)
}
func namedParameterizedPreparedStatement(t *testing.T) {
db, err := sql.Open("postgres", fmt.Sprintf("host=localhost port=%d database=sharded_db user=sharding_user password=sharding_user sslmode=disable", port))
if err != nil {
t.Fatalf("could not open connection: %+v", err)
}
stmt, err := db.Prepare("SELECT $1")
if err != nil {
t.Fatalf("could not prepare: %+v", err)
}
for i := 0; i < 100; i++ {
rows, err := stmt.Query(1)
if err != nil {
t.Fatalf("could not query: %+v", err)
}
_ = rows.Close()
}
}
func unnamedParameterizedPreparedStatement(t *testing.T) {
db, err := sql.Open("postgres", fmt.Sprintf("host=localhost port=%d database=sharded_db user=sharding_user password=sharding_user sslmode=disable", port))
if err != nil {
t.Fatalf("could not open connection: %+v", err)
}
for i := 0; i < 100; i++ {
// Under the hood QueryContext generates an unnamed parameterized prepared statement
rows, err := db.QueryContext(context.Background(), "SELECT $1", 1)
if err != nil {
t.Fatalf("could not query: %+v", err)
}
_ = rows.Close()
}
}

81
tests/go/setup.go Normal file
View File

@@ -0,0 +1,81 @@
package pgcat
import (
"context"
"database/sql"
_ "embed"
"fmt"
"math/rand"
"os"
"os/exec"
"strings"
"testing"
"time"
)
//go:embed pgcat.toml
var pgcatCfg string
var port = rand.Intn(32760-20000) + 20000
func setup(t *testing.T) func() {
cfg, err := os.CreateTemp("/tmp", "pgcat_cfg_*.toml")
if err != nil {
t.Fatalf("could not create temp file: %+v", err)
}
pgcatCfg = strings.Replace(pgcatCfg, "\"${PORT}\"", fmt.Sprintf("%d", port), 1)
_, err = cfg.Write([]byte(pgcatCfg))
if err != nil {
t.Fatalf("could not write temp file: %+v", err)
}
commandPath := "../../target/debug/pgcat"
if os.Getenv("CARGO_TARGET_DIR") != "" {
commandPath = os.Getenv("CARGO_TARGET_DIR") + "/debug/pgcat"
}
cmd := exec.Command(commandPath, cfg.Name())
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
go func() {
err = cmd.Run()
if err != nil {
t.Errorf("could not run pgcat: %+v", err)
}
}()
deadline, cancelFunc := context.WithDeadline(context.Background(), time.Now().Add(5*time.Second))
defer cancelFunc()
for {
select {
case <-deadline.Done():
break
case <-time.After(50 * time.Millisecond):
db, err := sql.Open("postgres", fmt.Sprintf("host=localhost port=%d database=pgcat user=admin_user password=admin_pass sslmode=disable", port))
if err != nil {
continue
}
rows, err := db.QueryContext(deadline, "SHOW STATS")
if err != nil {
continue
}
_ = rows.Close()
_ = db.Close()
break
}
break
}
return func() {
err := cmd.Process.Signal(os.Interrupt)
if err != nil {
t.Fatalf("could not interrupt pgcat: %+v", err)
}
err = os.Remove(cfg.Name())
if err != nil {
t.Fatalf("could not remove temp file: %+v", err)
}
}
}