Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 23 additions & 21 deletions integration/rust/tests/integration/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,14 +178,15 @@ async fn update_moves_row_between_shards() {
}

#[tokio::test]
async fn update_rejects_multiple_rows() {
async fn update_multiple_rows() {
let admin = admin_sqlx().await;
let _guard = RewriteConfigGuard::enable(admin.clone()).await;

let mut pools = connections_sqlx().await;
let pool = pools.swap_remove(1);

prepare_table(&pool).await;
// No PRIMARY KEY — both rows will get id=11 on shard 1 after the move.
prepare_table_no_pk(&pool).await;

let insert_first = format!("INSERT INTO {TEST_TABLE} (id, value) VALUES (1, 'old')");
pool.execute(insert_first.as_str())
Expand All @@ -200,36 +201,27 @@ async fn update_rejects_multiple_rows() {
let mut txn = pool.begin().await.unwrap();

let update = format!("UPDATE {TEST_TABLE} SET id = 11 WHERE id IN (1, 2)");
let err = txn
let result = txn
.execute(update.as_str())
.await
.expect_err("expected multi-row rewrite to fail");
let db_err = err
.as_database_error()
.expect("expected database error from proxy");
assert!(
db_err
.message()
.contains("sharding key update changes more than one row (2)"),
"unexpected error message: {}",
db_err.message()
);
txn.rollback().await.unwrap();
.expect("multi-row rewrite should succeed");
assert_eq!(result.rows_affected(), 2, "both rows updated");
txn.commit().await.unwrap();

assert_eq!(
count_on_shard(&pool, 0, 1).await,
1,
"row 1 still on shard 0"
0,
"row 1 moved off shard 0"
);
assert_eq!(
count_on_shard(&pool, 0, 2).await,
1,
"row 2 still on shard 0"
0,
"row 2 moved off shard 0"
);
assert_eq!(
count_on_shard(&pool, 1, 11).await,
0,
"no row inserted on shard 1"
2,
"both rows on shard 1 with id=11"
);

cleanup_table(&pool).await;
Expand Down Expand Up @@ -363,6 +355,16 @@ async fn prepare_table(pool: &Pool<Postgres>) {
}
}

async fn prepare_table_no_pk(pool: &Pool<Postgres>) {
for shard in [0, 1] {
let drop = format!("/* pgdog_shard: {shard} */ DROP TABLE IF EXISTS {TEST_TABLE}");
pool.execute(drop.as_str()).await.unwrap();
let create =
format!("/* pgdog_shard: {shard} */ CREATE TABLE {TEST_TABLE} (id BIGINT, value TEXT)");
pool.execute(create.as_str()).await.unwrap();
}
}

async fn cleanup_table(pool: &Pool<Postgres>) {
for shard in [0, 1] {
let drop = format!("/* pgdog_shard: {shard} */ DROP TABLE IF EXISTS {TEST_TABLE}");
Expand Down
150 changes: 79 additions & 71 deletions pgdog/src/frontend/client/query_engine/multi_step/update.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ use crate::{
use super::{Error, ForwardCheck, UpdateError};

#[derive(Debug, Clone, Default)]
pub(super) struct Row {
data_row: DataRow,
pub(super) struct Rows {
row_description: RowDescription,
data_rows: Vec<DataRow>,
}

#[derive(Debug)]
Expand Down Expand Up @@ -77,11 +77,11 @@ impl<'a> UpdateMulti<'a> {
return Ok(());
}

// Fetch the old row from whatever shard it is on.
let row = self.fetch_row(context).await?;
// Fetch the old rows from whatever shard(s) they are on.
let rows = self.fetch_rows(context).await?;

if let Some(row) = row {
self.insert_row(context, row).await?;
if let Some(rows) = rows {
self.insert_rows(context, rows).await?;
} else {
// This happens, but the UPDATE's WHERE clause
// doesn't match any rows, so this whole thing is a no-op.
Expand All @@ -93,72 +93,88 @@ impl<'a> UpdateMulti<'a> {
Ok(())
}

/// Create row.
pub(super) async fn insert_row(
pub(super) async fn insert_rows(
&mut self,
context: &mut QueryEngineContext<'_>,
row: Row,
rows: Rows,
) -> Result<(), Error> {
let mut request = self.rewrite.insert.build_request(
context.client_request,
&row.row_description,
&row.data_row,
)?;
self.route(&mut request, context)?;

let original_shard = context.client_request.route().shard();
let new_shard = request.route().shard();

// The new row maps to the same shard as the old row.
// We don't need to do the multi-step UPDATE anymore.
// Forward the original request as-is.
if original_shard.is_direct() && new_shard == original_shard {
debug!("[update] selected row is on the same shard");
self.execute_original(context).await
} else {
debug!("[update] executing multi-shard insert/delete");

// Check if we are allowed to do this operation by the config.
if self.engine.backend.cluster()?.rewrite().shard_key == RewriteMode::Error {
self.engine
.error_response(context, ErrorResponse::from_err(&UpdateError::Disabled))
.await?;
return Ok(());
}

if !context.in_transaction() || !self.engine.backend.is_multishard()
// Do this check at the last possible moment.
// Just in case we change how transactions are
// routed in the future.
{
self.engine.cleanup_backend(context)?;
return Err(UpdateError::TransactionRequired.into());
// Pre-build and route all INSERT requests upfront so we can
// check whether every row stays on the same shard before
// committing to the delete+insert workflow.
let mut requests = Vec::with_capacity(rows.data_rows.len());
let mut all_same_shard = original_shard.is_direct();
for data_row in &rows.data_rows {
let mut request = self.rewrite.insert.build_request(
context.client_request,
&rows.row_description,
data_row,
)?;
self.route(&mut request, context)?;
if request.route().shard() != original_shard {
all_same_shard = false;
}
requests.push(request);
}

if self.has_destructive_on_delete_reference(context)? {
return Err(UpdateError::ForeignKeyOnDelete.into());
}
// If every row maps back to the original shard, forward the
// original UPDATE as-is — no delete+insert needed.
if all_same_shard {
debug!("[update] all rows are on the same shard");
return self.execute_original(context).await;
}

debug!("[update] executing multi-shard insert/delete");

// Check if we are allowed to do this operation by the config.
if self.engine.backend.cluster()?.rewrite().shard_key == RewriteMode::Error {
self.engine
.error_response(context, ErrorResponse::from_err(&UpdateError::Disabled))
.await?;
return Ok(());
}

if !context.in_transaction() || !self.engine.backend.is_multishard()
// Do this check at the last possible moment.
// Just in case we change how transactions are
// routed in the future.
{
self.engine.cleanup_backend(context)?;
return Err(UpdateError::TransactionRequired.into());
}

if self.has_destructive_on_delete_reference(context)? {
return Err(UpdateError::ForeignKeyOnDelete.into());
}

self.delete_row(context).await?;
// Delete all matching rows from the original shard in one shot.
self.delete_row(context).await?;

let n = requests.len();
for mut request in requests {
self.execute_request_internal(
context,
&mut request,
self.rewrite.insert.is_returning(),
)
.await?;
}

self.engine
.process_server_message(context, CommandComplete::new("UPDATE 1").message()?) // We only allow to update one row at a time.
.await?;
self.engine
.process_server_message(
context,
ReadyForQuery::in_transaction(context.in_transaction()).message()?,
)
.await?;
self.engine
.process_server_message(
context,
CommandComplete::new(format!("UPDATE {n}")).message()?,
)
.await?;
self.engine
.process_server_message(
context,
ReadyForQuery::in_transaction(context.in_transaction()).message()?,
)
.await?;

Ok(())
}
Ok(())
}

fn has_destructive_on_delete_reference(
Expand Down Expand Up @@ -249,10 +265,10 @@ impl<'a> UpdateMulti<'a> {
.await
}

pub(super) async fn fetch_row(
pub(super) async fn fetch_rows(
&mut self,
context: &mut QueryEngineContext<'_>,
) -> Result<Option<Row>, Error> {
) -> Result<Option<Rows>, Error> {
let mut request = self.rewrite.select.build_request(context.client_request)?;
self.route(&mut request, context)?;

Expand All @@ -261,29 +277,21 @@ impl<'a> UpdateMulti<'a> {
.handle_client_request(&request, &mut Router::default(), false)
.await?;

let mut row = Row::default();
let mut rows = 0;
let mut rows = Rows::default();

while self.engine.backend.has_more_messages() {
let message = self.engine.read_server_message().await?;
match message.code() {
'D' => {
row.data_row = DataRow::try_from(message)?;
rows += 1;
}
'T' => row.row_description = RowDescription::try_from(message)?,
'D' => rows.data_rows.push(DataRow::try_from(message)?),
'T' => rows.row_description = RowDescription::try_from(message)?,
'E' => return Err(ErrorResponse::try_from(message)?.into()),
_ => (),
}
}

match rows {
0 => return Ok(None),
1 => (),
n => return Err(UpdateError::TooManyRows(n).into()),
}
debug!("[update] found {} rows to move", rows.data_rows.len());

Ok(Some(row))
Ok((!rows.data_rows.is_empty()).then_some(rows))
}

/// Returns true if the new sharding key resides on the same shard
Expand Down
Loading