diff --git a/src/query_router.rs b/src/query_router.rs index bc6ed2c..939abee 100644 --- a/src/query_router.rs +++ b/src/query_router.rs @@ -386,6 +386,18 @@ impl QueryRouter { } } + /// Determines if a query is mutable or not. + fn query_is_mutable_statement(q: &sqlparser::ast::Query) -> bool { + use sqlparser::ast::*; + + match q.body.as_ref() { + SetExpr::Insert(_) => true, + SetExpr::Update(_) => true, + SetExpr::Query(q) => Self::query_is_mutable_statement(q), + _ => false, + } + } + /// Try to infer which server to connect to based on the contents of the query. pub fn infer(&mut self, ast: &Vec) -> Result<(), Error> { if !self.pool_settings.query_parser_read_write_splitting { @@ -428,8 +440,9 @@ impl QueryRouter { }; let has_locks = !query.locks.is_empty(); + let is_mutable_statement = Self::query_is_mutable_statement(query); - if has_locks { + if has_locks || is_mutable_statement { self.active_role = Some(Role::Primary); } else if !visited_write_statement { // If we already visited a write statement, we should be going to the primary. @@ -1113,6 +1126,26 @@ mod test { assert_eq!(qr.role(), None); } + #[test] + fn test_split_cte_queries() { + QueryRouter::setup(); + let mut qr = QueryRouter::new(); + qr.pool_settings.query_parser_read_write_splitting = true; + qr.pool_settings.query_parser_enabled = true; + + let query = simple_query( + "WITH t AS ( + SELECT id FROM users WHERE name ILIKE '%ja%' + ) + UPDATE user_languages + SET settings = '{}' + FROM t WHERE t.id = user_id;", + ); + let ast = qr.parse(&query).unwrap(); + assert!(qr.infer(&ast).is_ok()); + assert_eq!(qr.role(), Some(Role::Primary)); + } + #[test] fn test_infer_replica() { QueryRouter::setup();