From bc8cbf3e864a2ecc577737c28d441deb293979f2 Mon Sep 17 00:00:00 2001 From: Yogesh Singla Date: Fri, 26 Jun 2026 23:14:07 +0530 Subject: [PATCH 1/3] feat: supported multiple rows update for sharding key update --- .../client/query_engine/multi_step/update.rs | 129 ++++++++---------- 1 file changed, 58 insertions(+), 71 deletions(-) diff --git a/pgdog/src/frontend/client/query_engine/multi_step/update.rs b/pgdog/src/frontend/client/query_engine/multi_step/update.rs index 9570d2b24..255ae1fe8 100644 --- a/pgdog/src/frontend/client/query_engine/multi_step/update.rs +++ b/pgdog/src/frontend/client/query_engine/multi_step/update.rs @@ -13,9 +13,9 @@ use crate::{ use super::{Error, ForwardCheck, UpdateError}; #[derive(Debug, Clone, Default)] -pub(super) struct Row { - data_row: DataRow, +pub(super) struct Rows { row_description: RowDescription, + data_rows: Vec, } #[derive(Debug)] @@ -77,11 +77,11 @@ impl<'a> UpdateMulti<'a> { return Ok(()); } - // Fetch the old row from whatever shard it is on. - let row = self.fetch_row(context).await?; + // Fetch the old rows from whatever shard(s) they are on. + let rows = self.fetch_rows(context).await?; - if let Some(row) = row { - self.insert_row(context, row).await?; + if let Some(rows) = rows { + self.insert_rows(context, rows).await?; } else { // This happens, but the UPDATE's WHERE clause // doesn't match any rows, so this whole thing is a no-op. @@ -93,72 +93,67 @@ impl<'a> UpdateMulti<'a> { Ok(()) } - /// Create row. - pub(super) async fn insert_row( + pub(super) async fn insert_rows( &mut self, context: &mut QueryEngineContext<'_>, - row: Row, + rows: Rows, ) -> Result<(), Error> { - let mut request = self.rewrite.insert.build_request( - context.client_request, - &row.row_description, - &row.data_row, - )?; - self.route(&mut request, context)?; + debug!("[update] executing multi-shard insert/delete"); - let original_shard = context.client_request.route().shard(); - let new_shard = request.route().shard(); + // Check if we are allowed to do this operation by the config. + if self.engine.backend.cluster()?.rewrite().shard_key == RewriteMode::Error { + self.engine + .error_response(context, ErrorResponse::from_err(&UpdateError::Disabled)) + .await?; + return Ok(()); + } - // The new row maps to the same shard as the old row. - // We don't need to do the multi-step UPDATE anymore. - // Forward the original request as-is. - if original_shard.is_direct() && new_shard == original_shard { - debug!("[update] selected row is on the same shard"); - self.execute_original(context).await - } else { - debug!("[update] executing multi-shard insert/delete"); - - // Check if we are allowed to do this operation by the config. - if self.engine.backend.cluster()?.rewrite().shard_key == RewriteMode::Error { - self.engine - .error_response(context, ErrorResponse::from_err(&UpdateError::Disabled)) - .await?; - return Ok(()); - } + if !context.in_transaction() || !self.engine.backend.is_multishard() + // Do this check at the last possible moment. + // Just in case we change how transactions are + // routed in the future. + { + self.engine.cleanup_backend(context)?; + return Err(UpdateError::TransactionRequired.into()); + } - if !context.in_transaction() || !self.engine.backend.is_multishard() - // Do this check at the last possible moment. - // Just in case we change how transactions are - // routed in the future. - { - self.engine.cleanup_backend(context)?; - return Err(UpdateError::TransactionRequired.into()); - } + if self.has_destructive_on_delete_reference(context)? { + return Err(UpdateError::ForeignKeyOnDelete.into()); + } - if self.has_destructive_on_delete_reference(context)? { - return Err(UpdateError::ForeignKeyOnDelete.into()); - } + // Delete all matching rows from the original shard in one shot. + self.delete_row(context).await?; - self.delete_row(context).await?; + let n = rows.data_rows.len(); + for data_row in rows.data_rows { + let mut request = self.rewrite.insert.build_request( + context.client_request, + &rows.row_description, + &data_row, + )?; + self.route(&mut request, context)?; self.execute_request_internal( context, &mut request, self.rewrite.insert.is_returning(), ) .await?; + } - self.engine - .process_server_message(context, CommandComplete::new("UPDATE 1").message()?) // We only allow to update one row at a time. - .await?; - self.engine - .process_server_message( - context, - ReadyForQuery::in_transaction(context.in_transaction()).message()?, - ) - .await?; + self.engine + .process_server_message( + context, + CommandComplete::new(&format!("UPDATE {n}")).message()?, + ) + .await?; + self.engine + .process_server_message( + context, + ReadyForQuery::in_transaction(context.in_transaction()).message()?, + ) + .await?; - Ok(()) - } + Ok(()) } fn has_destructive_on_delete_reference( @@ -249,10 +244,10 @@ impl<'a> UpdateMulti<'a> { .await } - pub(super) async fn fetch_row( + pub(super) async fn fetch_rows( &mut self, context: &mut QueryEngineContext<'_>, - ) -> Result, Error> { + ) -> Result, Error> { let mut request = self.rewrite.select.build_request(context.client_request)?; self.route(&mut request, context)?; @@ -261,29 +256,21 @@ impl<'a> UpdateMulti<'a> { .handle_client_request(&request, &mut Router::default(), false) .await?; - let mut row = Row::default(); - let mut rows = 0; + let mut rows = Rows::default(); while self.engine.backend.has_more_messages() { let message = self.engine.read_server_message().await?; match message.code() { - 'D' => { - row.data_row = DataRow::try_from(message)?; - rows += 1; - } - 'T' => row.row_description = RowDescription::try_from(message)?, + 'D' => rows.data_rows.push(DataRow::try_from(message)?), + 'T' => rows.row_description = RowDescription::try_from(message)?, 'E' => return Err(ErrorResponse::try_from(message)?.into()), _ => (), } } - match rows { - 0 => return Ok(None), - 1 => (), - n => return Err(UpdateError::TooManyRows(n).into()), - } + debug!("[update] found {} rows to move", rows.data_rows.len()); - Ok(Some(row)) + Ok((!rows.data_rows.is_empty()).then_some(rows)) } /// Returns true if the new sharding key resides on the same shard From a3aa0da955750844cdd3373837fa6dd6f69f5e4f Mon Sep 17 00:00:00 2001 From: Yogesh Singla Date: Sat, 27 Jun 2026 00:08:45 +0530 Subject: [PATCH 2/3] add same shard check and tests --- integration/rust/tests/integration/rewrite.rs | 45 ++++++++++--------- .../client/query_engine/multi_step/update.rs | 37 +++++++++++---- 2 files changed, 53 insertions(+), 29 deletions(-) diff --git a/integration/rust/tests/integration/rewrite.rs b/integration/rust/tests/integration/rewrite.rs index ac1635b20..8929ac564 100644 --- a/integration/rust/tests/integration/rewrite.rs +++ b/integration/rust/tests/integration/rewrite.rs @@ -178,14 +178,15 @@ async fn update_moves_row_between_shards() { } #[tokio::test] -async fn update_rejects_multiple_rows() { +async fn update_multiple_rows() { let admin = admin_sqlx().await; let _guard = RewriteConfigGuard::enable(admin.clone()).await; let mut pools = connections_sqlx().await; let pool = pools.swap_remove(1); - prepare_table(&pool).await; + // No PRIMARY KEY — both rows will get id=11 on shard 1 after the move. + prepare_table_no_pk(&pool).await; let insert_first = format!("INSERT INTO {TEST_TABLE} (id, value) VALUES (1, 'old')"); pool.execute(insert_first.as_str()) @@ -200,36 +201,27 @@ async fn update_rejects_multiple_rows() { let mut txn = pool.begin().await.unwrap(); let update = format!("UPDATE {TEST_TABLE} SET id = 11 WHERE id IN (1, 2)"); - let err = txn + let result = txn .execute(update.as_str()) .await - .expect_err("expected multi-row rewrite to fail"); - let db_err = err - .as_database_error() - .expect("expected database error from proxy"); - assert!( - db_err - .message() - .contains("sharding key update changes more than one row (2)"), - "unexpected error message: {}", - db_err.message() - ); - txn.rollback().await.unwrap(); + .expect("multi-row rewrite should succeed"); + assert_eq!(result.rows_affected(), 2, "both rows updated"); + txn.commit().await.unwrap(); assert_eq!( count_on_shard(&pool, 0, 1).await, - 1, - "row 1 still on shard 0" + 0, + "row 1 moved off shard 0" ); assert_eq!( count_on_shard(&pool, 0, 2).await, - 1, - "row 2 still on shard 0" + 0, + "row 2 moved off shard 0" ); assert_eq!( count_on_shard(&pool, 1, 11).await, - 0, - "no row inserted on shard 1" + 2, + "both rows on shard 1 with id=11" ); cleanup_table(&pool).await; @@ -363,6 +355,17 @@ async fn prepare_table(pool: &Pool) { } } +async fn prepare_table_no_pk(pool: &Pool) { + for shard in [0, 1] { + let drop = format!("/* pgdog_shard: {shard} */ DROP TABLE IF EXISTS {TEST_TABLE}"); + pool.execute(drop.as_str()).await.unwrap(); + let create = format!( + "/* pgdog_shard: {shard} */ CREATE TABLE {TEST_TABLE} (id BIGINT, value TEXT)" + ); + pool.execute(create.as_str()).await.unwrap(); + } +} + async fn cleanup_table(pool: &Pool) { for shard in [0, 1] { let drop = format!("/* pgdog_shard: {shard} */ DROP TABLE IF EXISTS {TEST_TABLE}"); diff --git a/pgdog/src/frontend/client/query_engine/multi_step/update.rs b/pgdog/src/frontend/client/query_engine/multi_step/update.rs index 255ae1fe8..8519e8130 100644 --- a/pgdog/src/frontend/client/query_engine/multi_step/update.rs +++ b/pgdog/src/frontend/client/query_engine/multi_step/update.rs @@ -98,6 +98,33 @@ impl<'a> UpdateMulti<'a> { context: &mut QueryEngineContext<'_>, rows: Rows, ) -> Result<(), Error> { + let original_shard = context.client_request.route().shard(); + + // Pre-build and route all INSERT requests upfront so we can + // check whether every row stays on the same shard before + // committing to the delete+insert workflow. + let mut requests = Vec::with_capacity(rows.data_rows.len()); + let mut all_same_shard = original_shard.is_direct(); + for data_row in &rows.data_rows { + let mut request = self.rewrite.insert.build_request( + context.client_request, + &rows.row_description, + data_row, + )?; + self.route(&mut request, context)?; + if request.route().shard() != original_shard { + all_same_shard = false; + } + requests.push(request); + } + + // If every row maps back to the original shard, forward the + // original UPDATE as-is — no delete+insert needed. + if all_same_shard { + debug!("[update] all rows are on the same shard"); + return self.execute_original(context).await; + } + debug!("[update] executing multi-shard insert/delete"); // Check if we are allowed to do this operation by the config. @@ -124,14 +151,8 @@ impl<'a> UpdateMulti<'a> { // Delete all matching rows from the original shard in one shot. self.delete_row(context).await?; - let n = rows.data_rows.len(); - for data_row in rows.data_rows { - let mut request = self.rewrite.insert.build_request( - context.client_request, - &rows.row_description, - &data_row, - )?; - self.route(&mut request, context)?; + let n = requests.len(); + for mut request in requests { self.execute_request_internal( context, &mut request, From 726925b8bf89a0c3926b86de16188f3f67222164 Mon Sep 17 00:00:00 2001 From: Yogesh Singla Date: Sat, 27 Jun 2026 00:22:45 +0530 Subject: [PATCH 3/3] fix cargo and clippy errors --- integration/rust/tests/integration/rewrite.rs | 5 ++--- pgdog/src/frontend/client/query_engine/multi_step/update.rs | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/integration/rust/tests/integration/rewrite.rs b/integration/rust/tests/integration/rewrite.rs index 8929ac564..2f9ad0089 100644 --- a/integration/rust/tests/integration/rewrite.rs +++ b/integration/rust/tests/integration/rewrite.rs @@ -359,9 +359,8 @@ async fn prepare_table_no_pk(pool: &Pool) { for shard in [0, 1] { let drop = format!("/* pgdog_shard: {shard} */ DROP TABLE IF EXISTS {TEST_TABLE}"); pool.execute(drop.as_str()).await.unwrap(); - let create = format!( - "/* pgdog_shard: {shard} */ CREATE TABLE {TEST_TABLE} (id BIGINT, value TEXT)" - ); + let create = + format!("/* pgdog_shard: {shard} */ CREATE TABLE {TEST_TABLE} (id BIGINT, value TEXT)"); pool.execute(create.as_str()).await.unwrap(); } } diff --git a/pgdog/src/frontend/client/query_engine/multi_step/update.rs b/pgdog/src/frontend/client/query_engine/multi_step/update.rs index 8519e8130..1774aab3b 100644 --- a/pgdog/src/frontend/client/query_engine/multi_step/update.rs +++ b/pgdog/src/frontend/client/query_engine/multi_step/update.rs @@ -164,7 +164,7 @@ impl<'a> UpdateMulti<'a> { self.engine .process_server_message( context, - CommandComplete::new(&format!("UPDATE {n}")).message()?, + CommandComplete::new(format!("UPDATE {n}")).message()?, ) .await?; self.engine