diff --git a/docs/docs/pypaimon/ray-data.md b/docs/docs/pypaimon/ray-data.md
index d160b4302f65..a6d98f8fcce0 100644
--- a/docs/docs/pypaimon/ray-data.md
+++ b/docs/docs/pypaimon/ray-data.md
@@ -386,6 +386,13 @@ columns (`s.*`). Requires the `datafusion` package: `pip install pypaimon[sql]`.
Use `lit()` for literals starting with `s.` or `t.`.
- `condition`: an optional SQL-style boolean expression. Use `s.
` and
`t.` to reference source and target columns.
+- Multiple clauses are evaluated in order; the first matching condition wins:
+ ```python
+ when_matched=[
+ WhenMatched(update="*", condition="s.ts > t.ts"),
+ WhenMatched(update="*"), # fallback for unmatched rows
+ ]
+ ```
**Parameters:**
- `source`: a `ray.data.Dataset`, `pyarrow.Table`, `pandas.DataFrame`, or a
diff --git a/paimon-python/pypaimon/ray/data_evolution_merge_into.py b/paimon-python/pypaimon/ray/data_evolution_merge_into.py
index fa824b44a2f9..cbfcef907d81 100644
--- a/paimon-python/pypaimon/ray/data_evolution_merge_into.py
+++ b/paimon-python/pypaimon/ray/data_evolution_merge_into.py
@@ -93,12 +93,15 @@ def _prepare(target, source, catalog_options, when_matched, when_not_matched, on
raise ValueError(
"At least one of when_matched or when_not_matched must be non-empty."
)
- if len(when_matched) > 1 or len(when_not_matched) > 1:
- raise NotImplementedError(
- "merge_into currently supports a single WhenMatched and a single "
- "WhenNotMatched clause; multi-clause fall-through will be added "
- "in a follow-up PR."
- )
+ for label, clauses in [("when_matched", when_matched),
+ ("when_not_matched", when_not_matched)]:
+ for i, clause in enumerate(clauses[:-1]):
+ if clause.condition is None:
+ raise ValueError(
+ f"Only the last {label} clause may omit its condition. "
+ f"Clause at index {i} has no condition, making subsequent "
+ f"clauses unreachable."
+ )
target_on_cols, source_on_cols = _normalize_on(on)
from pypaimon.catalog.catalog_factory import CatalogFactory
diff --git a/paimon-python/pypaimon/ray/data_evolution_merge_join.py b/paimon-python/pypaimon/ray/data_evolution_merge_join.py
index f01f9b59aab8..14088979f893 100644
--- a/paimon-python/pypaimon/ray/data_evolution_merge_join.py
+++ b/paimon-python/pypaimon/ray/data_evolution_merge_join.py
@@ -87,43 +87,57 @@ def build_matched_update_ds(
right_on=tuple(f"s.{c}" for c in source_on),
)
- # MVP supports a single matched clause; future fan-out (conditions, multi-
- # clause fall-through) must thread every clause's spec through the
- # transform — guard so silent first-only behaviour can't sneak in.
- assert len(clauses) == 1, (
- f"build_matched_update_ds expected 1 clause, got {len(clauses)}"
- )
- spec = clauses[0].spec
- condition = clauses[0].condition
captured_update_cols = list(update_cols)
captured_row_id_name = row_id_name
captured_on_pairs = list(zip(source_on, target_on))
captured_schema = update_schema
- captured_apply = None
- captured_rewritten = None
- if condition is not None:
- from pypaimon.ray.merge_condition import (
- apply_condition, remap_source_on_keys, rewrite_condition,
- )
- on_map = dict(zip(source_on, target_on))
- captured_rewritten = remap_source_on_keys(
- rewrite_condition(condition), on_map,
- )
- captured_apply = apply_condition
+ on_map = dict(zip(source_on, target_on))
+ prepared_clauses = []
+ for clause in clauses:
+ rewritten = None
+ if clause.condition is not None:
+ from pypaimon.ray.merge_condition import (
+ remap_source_on_keys, rewrite_condition,
+ )
+ rewritten = remap_source_on_keys(
+ rewrite_condition(clause.condition), on_map,
+ )
+ prepared_clauses.append((clause.spec, rewritten))
+
+ _filter_batch = None
+ if any(r is not None for _, r in prepared_clauses):
+ from pypaimon.ray.merge_condition import filter_batch as _filter_batch
def _transform(batch: pa.Table) -> pa.Table:
- if captured_apply is not None:
- batch = captured_apply(
- batch, captured_rewritten, captured_schema,
- )
- if batch.num_rows == 0:
- return batch
- return vectorized_matched_transform(
- batch, spec, captured_on_pairs,
- captured_update_cols, captured_row_id_name,
- captured_schema,
- )
+ remaining = batch
+ parts = []
+ for spec, rewritten in prepared_clauses:
+ if remaining.num_rows == 0:
+ break
+ if rewritten is not None:
+ matched = _filter_batch(
+ remaining, rewritten, _pre_rewritten=True,
+ )
+ else:
+ matched = remaining
+ if matched.num_rows == 0:
+ continue
+ parts.append(vectorized_matched_transform(
+ matched, spec, captured_on_pairs,
+ captured_update_cols, captured_row_id_name,
+ captured_schema,
+ ))
+ if rewritten is not None and matched.num_rows < remaining.num_rows:
+ not_cond = f"COALESCE(NOT ({rewritten}), TRUE)"
+ remaining = _filter_batch(
+ remaining, not_cond, _pre_rewritten=True,
+ )
+ else:
+ remaining = remaining.slice(0, 0)
+ if not parts:
+ return captured_schema.empty_table()
+ return pa.concat_tables(parts)
return joined.map_batches(_transform, **_map_kwargs(ray_remote_args))
@@ -324,32 +338,47 @@ def build_not_matched_insert_ds(
right_on=tuple(f"t.{c}" for c in target_on),
)
- # MVP supports a single not-matched clause; see build_matched_update_ds
- # for why we assert instead of silently dropping the rest.
- assert len(clauses) == 1, (
- f"build_not_matched_insert_ds expected 1 clause, got {len(clauses)}"
- )
- spec = clauses[0].spec
- condition = clauses[0].condition
- captured_apply = None
- captured_rewritten = None
- if condition is not None:
- from pypaimon.ray.merge_condition import apply_condition, rewrite_condition
- captured_rewritten = rewrite_condition(condition)
- captured_apply = apply_condition
+ prepared_clauses = []
+ for clause in clauses:
+ rewritten = None
+ if clause.condition is not None:
+ from pypaimon.ray.merge_condition import rewrite_condition
+ rewritten = rewrite_condition(clause.condition)
+ prepared_clauses.append((clause.spec, rewritten))
+
+ _filter_batch_nm = None
+ if any(r is not None for _, r in prepared_clauses):
+ from pypaimon.ray.merge_condition import filter_batch as _filter_batch_nm
def _transform(batch: pa.Table) -> pa.Table:
- if captured_apply is not None:
- batch = captured_apply(
- batch, captured_rewritten, out_schema,
- )
- if batch.num_rows == 0:
- return _coerce_large_string_types(batch)
- return _coerce_large_string_types(
- vectorized_insert_transform(
- batch, spec, captured_field_names, out_schema
- )
- )
+ remaining = batch
+ parts = []
+ for spec, rewritten in prepared_clauses:
+ if remaining.num_rows == 0:
+ break
+ if rewritten is not None:
+ matched = _filter_batch_nm(
+ remaining, rewritten, _pre_rewritten=True,
+ )
+ if matched.num_rows > 0:
+ parts.append(vectorized_insert_transform(
+ matched, spec, captured_field_names, out_schema
+ ))
+ if matched.num_rows < remaining.num_rows:
+ not_cond = f"COALESCE(NOT ({rewritten}), TRUE)"
+ remaining = _filter_batch_nm(
+ remaining, not_cond, _pre_rewritten=True,
+ )
+ else:
+ remaining = remaining.slice(0, 0)
+ else:
+ parts.append(vectorized_insert_transform(
+ remaining, spec, captured_field_names, out_schema
+ ))
+ remaining = remaining.slice(0, 0)
+ if not parts:
+ return _coerce_large_string_types(out_schema.empty_table())
+ return _coerce_large_string_types(pa.concat_tables(parts))
return unmatched.map_batches(
_transform, **_map_kwargs(ray_remote_args)
diff --git a/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py b/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py
index ca06d43e5341..b54eeb5cf0c9 100644
--- a/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py
+++ b/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py
@@ -153,6 +153,40 @@ def test_no_clause_raises(self):
num_partitions=_TEST_NUM_PARTITIONS,
)
+ def test_unconditional_non_last_matched_rejected(self):
+ target = self._create_table()
+ with self.assertRaises(ValueError) as ctx:
+ merge_into(
+ target=target,
+ source=self._source(),
+ catalog_options=self.catalog_options,
+ on=['id'],
+ when_matched=[
+ WhenMatched(update='*'),
+ WhenMatched(update={'age': 's.age'}, condition='s.age > 10'),
+ ],
+ num_partitions=_TEST_NUM_PARTITIONS,
+ )
+ self.assertIn('when_matched', str(ctx.exception))
+ self.assertIn('unreachable', str(ctx.exception))
+
+ def test_unconditional_non_last_not_matched_rejected(self):
+ target = self._create_table()
+ with self.assertRaises(ValueError) as ctx:
+ merge_into(
+ target=target,
+ source=self._source(),
+ catalog_options=self.catalog_options,
+ on=['id'],
+ when_not_matched=[
+ WhenNotMatched(insert='*'),
+ WhenNotMatched(insert='*', condition='s.age > 10'),
+ ],
+ num_partitions=_TEST_NUM_PARTITIONS,
+ )
+ self.assertIn('when_not_matched', str(ctx.exception))
+ self.assertIn('unreachable', str(ctx.exception))
+
def test_non_de_table_rejected(self):
target = self._create_table(options={'row-tracking.enabled': 'true'})
with self.assertRaises(ValueError) as ctx:
@@ -1315,6 +1349,305 @@ def test_target_col_helper(self):
self.assertEqual(out['name'], ['keep'])
self.assertEqual(out['age'], [99])
+ @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
+ def test_multi_matched_clause_fall_through(self):
+ target = self._create_table()
+ self._write(
+ target,
+ pa.Table.from_pydict(
+ {
+ 'id': pa.array([1, 2, 3], type=pa.int32()),
+ 'name': ['a', 'b', 'c'],
+ 'age': pa.array([10, 20, 30], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ ),
+ )
+
+ source = pa.Table.from_pydict(
+ {
+ 'id': pa.array([1, 2, 3], type=pa.int32()),
+ 'name': ['a2', 'b2', 'c2'],
+ 'age': pa.array([99, 88, 77], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ )
+
+ merge_into(
+ target=target,
+ source=source,
+ catalog_options=self.catalog_options,
+ on=['id'],
+ when_matched=[
+ WhenMatched(update='*', condition='s.age > 80'),
+ WhenMatched(update='*'),
+ ],
+ num_partitions=_TEST_NUM_PARTITIONS,
+ )
+
+ out = self._read_sorted(target)
+ self.assertEqual(out['id'], [1, 2, 3])
+ self.assertEqual(out['name'], ['a2', 'b2', 'c2'])
+ self.assertEqual(out['age'], [99, 88, 77])
+
+ @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
+ def test_multi_not_matched_clause_fall_through(self):
+ target = self._create_table()
+
+ source = pa.Table.from_pydict(
+ {
+ 'id': pa.array([1, 2, 3], type=pa.int32()),
+ 'name': ['a', 'b', 'c'],
+ 'age': pa.array([25, 15, 5], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ )
+
+ merge_into(
+ target=target,
+ source=source,
+ catalog_options=self.catalog_options,
+ on=['id'],
+ when_not_matched=[
+ WhenNotMatched(insert='*', condition='s.age >= 20'),
+ WhenNotMatched(insert='*'),
+ ],
+ num_partitions=_TEST_NUM_PARTITIONS,
+ )
+
+ out = self._read_sorted(target)
+ self.assertEqual(out['id'], [1, 2, 3])
+
+ @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
+ def test_multi_matched_null_falls_through(self):
+ target = self._create_table()
+ self._write(
+ target,
+ pa.Table.from_pydict(
+ {
+ 'id': pa.array([1, 2, 3], type=pa.int32()),
+ 'name': ['a', 'b', 'c'],
+ 'age': pa.array([10, 20, 30], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ ),
+ )
+
+ source = pa.Table.from_pydict(
+ {
+ 'id': pa.array([1, 2, 3], type=pa.int32()),
+ 'name': ['a2', 'b2', 'c2'],
+ 'age': pa.array([None, 50, 60], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ )
+
+ merge_into(
+ target=target,
+ source=source,
+ catalog_options=self.catalog_options,
+ on=['id'],
+ when_matched=[
+ WhenMatched(update='*', condition='s.age > 40'),
+ WhenMatched(update='*'),
+ ],
+ num_partitions=_TEST_NUM_PARTITIONS,
+ )
+
+ out = self._read_sorted(target)
+ self.assertEqual(out['id'], [1, 2, 3])
+ self.assertEqual(out['name'], ['a2', 'b2', 'c2'])
+ self.assertEqual(out['age'], [None, 50, 60])
+
+ @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
+ def test_multi_not_matched_null_falls_through(self):
+ target = self._create_table()
+
+ source = pa.Table.from_pydict(
+ {
+ 'id': pa.array([1, 2], type=pa.int32()),
+ 'name': ['a', 'b'],
+ 'age': pa.array([None, 25], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ )
+
+ merge_into(
+ target=target,
+ source=source,
+ catalog_options=self.catalog_options,
+ on=['id'],
+ when_not_matched=[
+ WhenNotMatched(insert='*', condition='s.age > 20'),
+ WhenNotMatched(insert='*'),
+ ],
+ num_partitions=_TEST_NUM_PARTITIONS,
+ )
+
+ out = self._read_sorted(target)
+ self.assertEqual(out['id'], [1, 2])
+ self.assertEqual(out['age'], [None, 25])
+
+ @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
+ def test_multi_clause_no_match_skipped(self):
+ target = self._create_table()
+ self._write(
+ target,
+ pa.Table.from_pydict(
+ {
+ 'id': pa.array([1, 2], type=pa.int32()),
+ 'name': ['a', 'b'],
+ 'age': pa.array([10, 20], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ ),
+ )
+
+ source = pa.Table.from_pydict(
+ {
+ 'id': pa.array([1, 2], type=pa.int32()),
+ 'name': ['a2', 'b2'],
+ 'age': pa.array([5, 5], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ )
+
+ merge_into(
+ target=target,
+ source=source,
+ catalog_options=self.catalog_options,
+ on=['id'],
+ when_matched=[
+ WhenMatched(update='*', condition='s.age > 50'),
+ WhenMatched(update='*', condition='s.age > 30'),
+ ],
+ num_partitions=_TEST_NUM_PARTITIONS,
+ )
+
+ out = self._read_sorted(target)
+ self.assertEqual(out['name'], ['a', 'b'])
+ self.assertEqual(out['age'], [10, 20])
+
+ @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
+ def test_multi_clause_first_wins(self):
+ target = self._create_table()
+ self._write(
+ target,
+ pa.Table.from_pydict(
+ {
+ 'id': pa.array([1], type=pa.int32()),
+ 'name': ['old'],
+ 'age': pa.array([10], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ ),
+ )
+
+ source = pa.Table.from_pydict(
+ {
+ 'id': pa.array([1], type=pa.int32()),
+ 'name': ['first'],
+ 'age': pa.array([99], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ )
+
+ merge_into(
+ target=target,
+ source=source,
+ catalog_options=self.catalog_options,
+ on=['id'],
+ when_matched=[
+ WhenMatched(update={'name': 's.name'},
+ condition='s.age > 50'),
+ WhenMatched(update={'age': 's.age'},
+ condition='s.age > 10'),
+ ],
+ num_partitions=_TEST_NUM_PARTITIONS,
+ )
+
+ out = self._read_sorted(target)
+ self.assertEqual(out['name'], ['first'])
+ self.assertEqual(out['age'], [10])
+
+ @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
+ def test_multi_clause_duplicate_source_one_actionable(self):
+ target = self._create_table()
+ self._write(
+ target,
+ pa.Table.from_pydict(
+ {
+ 'id': pa.array([1], type=pa.int32()),
+ 'name': ['a'],
+ 'age': pa.array([10], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ ),
+ )
+
+ source = pa.Table.from_pydict(
+ {
+ 'id': pa.array([1, 1], type=pa.int32()),
+ 'name': ['x', 'y'],
+ 'age': pa.array([99, 5], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ )
+
+ merge_into(
+ target=target,
+ source=source,
+ catalog_options=self.catalog_options,
+ on=['id'],
+ when_matched=[
+ WhenMatched(update='*', condition='s.age > 50'),
+ WhenMatched(update='*', condition='s.age > 80'),
+ ],
+ num_partitions=_TEST_NUM_PARTITIONS,
+ )
+
+ out = self._read_sorted(target)
+ self.assertEqual(out['name'], ['x'])
+ self.assertEqual(out['age'], [99])
+
+ @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON)
+ def test_multi_clause_duplicate_both_actionable_raises(self):
+ target = self._create_table()
+ self._write(
+ target,
+ pa.Table.from_pydict(
+ {
+ 'id': pa.array([1], type=pa.int32()),
+ 'name': ['a'],
+ 'age': pa.array([10], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ ),
+ )
+
+ source = pa.Table.from_pydict(
+ {
+ 'id': pa.array([1, 1], type=pa.int32()),
+ 'name': ['x', 'y'],
+ 'age': pa.array([99, 50], type=pa.int32()),
+ },
+ schema=self.pa_schema,
+ )
+
+ with self.assertRaises(Exception) as ctx:
+ merge_into(
+ target=target,
+ source=source,
+ catalog_options=self.catalog_options,
+ on=['id'],
+ when_matched=[
+ WhenMatched(update='*', condition='s.age > 80'),
+ WhenMatched(update='*', condition='s.age > 30'),
+ ],
+ num_partitions=_TEST_NUM_PARTITIONS,
+ )
+ self.assertIn('multiple source rows', str(ctx.exception))
+
class TargetProjectionTest(unittest.TestCase):