diff --git a/src/client.rs b/src/client.rs index 05916a7..c5d9d3e 100644 --- a/src/client.rs +++ b/src/client.rs @@ -55,6 +55,7 @@ impl Client { // TODO: perform actual auth. // TODO: record startup parameters client sends over. auth_ok(&mut stream).await?; + server_parameters(&mut stream).await?; ready_for_query(&mut stream).await?; let (read, write) = stream.into_split(); @@ -135,6 +136,10 @@ impl Client { 'X' => { // Client closing + if server.in_transaction() { + server.query("ROLLBACK").await?; + } + return Ok(()); } diff --git a/src/errors.rs b/src/errors.rs index 47510f4..cb7e756 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -6,4 +6,5 @@ pub enum Error { ProtocolSyncError, ServerError, ServerTimeout, + DirtyServer, } diff --git a/src/messages.rs b/src/messages.rs index 86ad681..2ba4b37 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -16,6 +16,18 @@ pub async fn auth_ok(stream: &mut TcpStream) -> Result<(), Error> { Ok(write_all(stream, auth_ok).await?) } +pub async fn server_parameters(stream: &mut TcpStream) -> Result<(), Error> { + let client_encoding = BytesMut::from(&b"client_encoding\0UTF8\0"[..]); + let len = client_encoding.len() as i32 + 4; // TODO: add more parameters here + let mut res = BytesMut::with_capacity(len as usize + 1); + + res.put_u8(b'S'); + res.put_i32(len); + res.put_slice(&client_encoding[..]); + + Ok(write_all(stream, res).await?) +} + pub async fn ready_for_query(stream: &mut TcpStream) -> Result<(), Error> { let mut bytes = BytesMut::with_capacity(5); diff --git a/src/pool.rs b/src/pool.rs index b2074a4..2683e17 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -50,6 +50,11 @@ impl ManageConnection for ServerPool { async fn is_valid(&self, conn: &mut PooledConnection<'_, Self>) -> Result<(), Self::Error> { let server = &mut *conn; + // Client disconnected before cleaning up + if server.in_transaction() { + return Err(Error::DirtyServer); + } + // If this fails, the connection will be closed and another will be grabbed from the pool quietly :-). // Failover, step 1, complete. match tokio::time::timeout( diff --git a/tests/python/.gitignore b/tests/python/.gitignore new file mode 100644 index 0000000..eba74f4 --- /dev/null +++ b/tests/python/.gitignore @@ -0,0 +1 @@ +venv/ \ No newline at end of file diff --git a/tests/python/requirements.txt b/tests/python/requirements.txt new file mode 100644 index 0000000..d7661d4 --- /dev/null +++ b/tests/python/requirements.txt @@ -0,0 +1 @@ +psycopg2==2.9.3 diff --git a/tests/python/tests.py b/tests/python/tests.py new file mode 100644 index 0000000..c3fa47a --- /dev/null +++ b/tests/python/tests.py @@ -0,0 +1,9 @@ +import psycopg2 + +conn = psycopg2.connect("postgres://random:password@127.0.0.1:5433/db") +cur = conn.cursor() + +cur.execute("SELECT 123"); +res = cur.fetchall() + +print(res) \ No newline at end of file