diff --git a/docs/docs/pypaimon/ray-data.md b/docs/docs/pypaimon/ray-data.md index d160b4302f65..a6d98f8fcce0 100644 --- a/docs/docs/pypaimon/ray-data.md +++ b/docs/docs/pypaimon/ray-data.md @@ -386,6 +386,13 @@ columns (`s.*`). Requires the `datafusion` package: `pip install pypaimon[sql]`. Use `lit()` for literals starting with `s.` or `t.`. - `condition`: an optional SQL-style boolean expression. Use `s.` and `t.` to reference source and target columns. +- Multiple clauses are evaluated in order; the first matching condition wins: + ```python + when_matched=[ + WhenMatched(update="*", condition="s.ts > t.ts"), + WhenMatched(update="*"), # fallback: matched rows that fail the condition above + ] + ``` **Parameters:** - `source`: a `ray.data.Dataset`, `pyarrow.Table`, `pandas.DataFrame`, or a diff --git a/paimon-python/pypaimon/ray/data_evolution_merge_into.py b/paimon-python/pypaimon/ray/data_evolution_merge_into.py index fa824b44a2f9..cbfcef907d81 100644 --- a/paimon-python/pypaimon/ray/data_evolution_merge_into.py +++ b/paimon-python/pypaimon/ray/data_evolution_merge_into.py @@ -93,12 +93,15 @@ def _prepare(target, source, catalog_options, when_matched, when_not_matched, on raise ValueError( "At least one of when_matched or when_not_matched must be non-empty." ) - if len(when_matched) > 1 or len(when_not_matched) > 1: - raise NotImplementedError( - "merge_into currently supports a single WhenMatched and a single " - "WhenNotMatched clause; multi-clause fall-through will be added " - "in a follow-up PR." - ) + for label, clauses in [("when_matched", when_matched), + ("when_not_matched", when_not_matched)]: + for i, clause in enumerate(clauses[:-1]): + if clause.condition is None: + raise ValueError( + f"Only the last {label} clause may omit its condition. " + f"Clause at index {i} has no condition, making subsequent " + f"clauses unreachable."
+ ) target_on_cols, source_on_cols = _normalize_on(on) from pypaimon.catalog.catalog_factory import CatalogFactory diff --git a/paimon-python/pypaimon/ray/data_evolution_merge_join.py b/paimon-python/pypaimon/ray/data_evolution_merge_join.py index f01f9b59aab8..14088979f893 100644 --- a/paimon-python/pypaimon/ray/data_evolution_merge_join.py +++ b/paimon-python/pypaimon/ray/data_evolution_merge_join.py @@ -87,43 +87,57 @@ def build_matched_update_ds( right_on=tuple(f"s.{c}" for c in source_on), ) - # MVP supports a single matched clause; future fan-out (conditions, multi- - # clause fall-through) must thread every clause's spec through the - # transform — guard so silent first-only behaviour can't sneak in. - assert len(clauses) == 1, ( - f"build_matched_update_ds expected 1 clause, got {len(clauses)}" - ) - spec = clauses[0].spec - condition = clauses[0].condition captured_update_cols = list(update_cols) captured_row_id_name = row_id_name captured_on_pairs = list(zip(source_on, target_on)) captured_schema = update_schema - captured_apply = None - captured_rewritten = None - if condition is not None: - from pypaimon.ray.merge_condition import ( - apply_condition, remap_source_on_keys, rewrite_condition, - ) - on_map = dict(zip(source_on, target_on)) - captured_rewritten = remap_source_on_keys( - rewrite_condition(condition), on_map, - ) - captured_apply = apply_condition + on_map = dict(zip(source_on, target_on)) + prepared_clauses = [] + for clause in clauses: + rewritten = None + if clause.condition is not None: + from pypaimon.ray.merge_condition import ( + remap_source_on_keys, rewrite_condition, + ) + rewritten = remap_source_on_keys( + rewrite_condition(clause.condition), on_map, + ) + prepared_clauses.append((clause.spec, rewritten)) + + _filter_batch = None + if any(r is not None for _, r in prepared_clauses): + from pypaimon.ray.merge_condition import filter_batch as _filter_batch def _transform(batch: pa.Table) -> pa.Table: - if captured_apply is not 
None: - batch = captured_apply( - batch, captured_rewritten, captured_schema, - ) - if batch.num_rows == 0: - return batch - return vectorized_matched_transform( - batch, spec, captured_on_pairs, - captured_update_cols, captured_row_id_name, - captured_schema, - ) + remaining = batch + parts = [] + for spec, rewritten in prepared_clauses: + if remaining.num_rows == 0: + break + if rewritten is not None: + matched = _filter_batch( + remaining, rewritten, _pre_rewritten=True, + ) + else: + matched = remaining + if matched.num_rows == 0: + continue + parts.append(vectorized_matched_transform( + matched, spec, captured_on_pairs, + captured_update_cols, captured_row_id_name, + captured_schema, + )) + if rewritten is not None and matched.num_rows < remaining.num_rows: + not_cond = f"COALESCE(NOT ({rewritten}), TRUE)" + remaining = _filter_batch( + remaining, not_cond, _pre_rewritten=True, + ) + else: + remaining = remaining.slice(0, 0) + if not parts: + return captured_schema.empty_table() + return pa.concat_tables(parts) return joined.map_batches(_transform, **_map_kwargs(ray_remote_args)) @@ -324,32 +338,47 @@ def build_not_matched_insert_ds( right_on=tuple(f"t.{c}" for c in target_on), ) - # MVP supports a single not-matched clause; see build_matched_update_ds - # for why we assert instead of silently dropping the rest. 
- assert len(clauses) == 1, ( - f"build_not_matched_insert_ds expected 1 clause, got {len(clauses)}" - ) - spec = clauses[0].spec - condition = clauses[0].condition - captured_apply = None - captured_rewritten = None - if condition is not None: - from pypaimon.ray.merge_condition import apply_condition, rewrite_condition - captured_rewritten = rewrite_condition(condition) - captured_apply = apply_condition + prepared_clauses = [] + for clause in clauses: + rewritten = None + if clause.condition is not None: + from pypaimon.ray.merge_condition import rewrite_condition + rewritten = rewrite_condition(clause.condition) + prepared_clauses.append((clause.spec, rewritten)) + + _filter_batch_nm = None + if any(r is not None for _, r in prepared_clauses): + from pypaimon.ray.merge_condition import filter_batch as _filter_batch_nm def _transform(batch: pa.Table) -> pa.Table: - if captured_apply is not None: - batch = captured_apply( - batch, captured_rewritten, out_schema, - ) - if batch.num_rows == 0: - return _coerce_large_string_types(batch) - return _coerce_large_string_types( - vectorized_insert_transform( - batch, spec, captured_field_names, out_schema - ) - ) + remaining = batch + parts = [] + for spec, rewritten in prepared_clauses: + if remaining.num_rows == 0: + break + if rewritten is not None: + matched = _filter_batch_nm( + remaining, rewritten, _pre_rewritten=True, + ) + if matched.num_rows > 0: + parts.append(vectorized_insert_transform( + matched, spec, captured_field_names, out_schema + )) + if matched.num_rows < remaining.num_rows: + not_cond = f"COALESCE(NOT ({rewritten}), TRUE)" + remaining = _filter_batch_nm( + remaining, not_cond, _pre_rewritten=True, + ) + else: + remaining = remaining.slice(0, 0) + else: + parts.append(vectorized_insert_transform( + remaining, spec, captured_field_names, out_schema + )) + remaining = remaining.slice(0, 0) + if not parts: + return _coerce_large_string_types(out_schema.empty_table()) + return 
_coerce_large_string_types(pa.concat_tables(parts)) return unmatched.map_batches( _transform, **_map_kwargs(ray_remote_args) diff --git a/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py b/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py index ca06d43e5341..b54eeb5cf0c9 100644 --- a/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py +++ b/paimon-python/pypaimon/tests/ray_data_evolution_merge_into_test.py @@ -153,6 +153,40 @@ def test_no_clause_raises(self): num_partitions=_TEST_NUM_PARTITIONS, ) + def test_unconditional_non_last_matched_rejected(self): + target = self._create_table() + with self.assertRaises(ValueError) as ctx: + merge_into( + target=target, + source=self._source(), + catalog_options=self.catalog_options, + on=['id'], + when_matched=[ + WhenMatched(update='*'), + WhenMatched(update={'age': 's.age'}, condition='s.age > 10'), + ], + num_partitions=_TEST_NUM_PARTITIONS, + ) + self.assertIn('when_matched', str(ctx.exception)) + self.assertIn('unreachable', str(ctx.exception)) + + def test_unconditional_non_last_not_matched_rejected(self): + target = self._create_table() + with self.assertRaises(ValueError) as ctx: + merge_into( + target=target, + source=self._source(), + catalog_options=self.catalog_options, + on=['id'], + when_not_matched=[ + WhenNotMatched(insert='*'), + WhenNotMatched(insert='*', condition='s.age > 10'), + ], + num_partitions=_TEST_NUM_PARTITIONS, + ) + self.assertIn('when_not_matched', str(ctx.exception)) + self.assertIn('unreachable', str(ctx.exception)) + def test_non_de_table_rejected(self): target = self._create_table(options={'row-tracking.enabled': 'true'}) with self.assertRaises(ValueError) as ctx: @@ -1315,6 +1349,305 @@ def test_target_col_helper(self): self.assertEqual(out['name'], ['keep']) self.assertEqual(out['age'], [99]) + @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON) + def test_multi_matched_clause_fall_through(self): + target = self._create_table() + 
self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1, 2, 3], type=pa.int32()), + 'name': ['a', 'b', 'c'], + 'age': pa.array([10, 20, 30], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source = pa.Table.from_pydict( + { + 'id': pa.array([1, 2, 3], type=pa.int32()), + 'name': ['a2', 'b2', 'c2'], + 'age': pa.array([99, 88, 77], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[ + WhenMatched(update='*', condition='s.age > 80'), + WhenMatched(update='*'), + ], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + out = self._read_sorted(target) + self.assertEqual(out['id'], [1, 2, 3]) + self.assertEqual(out['name'], ['a2', 'b2', 'c2']) + self.assertEqual(out['age'], [99, 88, 77]) + + @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON) + def test_multi_not_matched_clause_fall_through(self): + target = self._create_table() + + source = pa.Table.from_pydict( + { + 'id': pa.array([1, 2, 3], type=pa.int32()), + 'name': ['a', 'b', 'c'], + 'age': pa.array([25, 15, 5], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_not_matched=[ + WhenNotMatched(insert='*', condition='s.age >= 20'), + WhenNotMatched(insert='*'), + ], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + out = self._read_sorted(target) + self.assertEqual(out['id'], [1, 2, 3]) + + @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON) + def test_multi_matched_null_falls_through(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1, 2, 3], type=pa.int32()), + 'name': ['a', 'b', 'c'], + 'age': pa.array([10, 20, 30], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source = pa.Table.from_pydict( + { + 'id': pa.array([1, 2, 3], type=pa.int32()), + 'name': ['a2', 'b2', 'c2'], + 'age': pa.array([None, 50, 60], 
type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[ + WhenMatched(update='*', condition='s.age > 40'), + WhenMatched(update='*'), + ], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + out = self._read_sorted(target) + self.assertEqual(out['id'], [1, 2, 3]) + self.assertEqual(out['name'], ['a2', 'b2', 'c2']) + self.assertEqual(out['age'], [None, 50, 60]) + + @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON) + def test_multi_not_matched_null_falls_through(self): + target = self._create_table() + + source = pa.Table.from_pydict( + { + 'id': pa.array([1, 2], type=pa.int32()), + 'name': ['a', 'b'], + 'age': pa.array([None, 25], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_not_matched=[ + WhenNotMatched(insert='*', condition='s.age > 20'), + WhenNotMatched(insert='*'), + ], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + out = self._read_sorted(target) + self.assertEqual(out['id'], [1, 2]) + self.assertEqual(out['age'], [None, 25]) + + @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON) + def test_multi_clause_no_match_skipped(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1, 2], type=pa.int32()), + 'name': ['a', 'b'], + 'age': pa.array([10, 20], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source = pa.Table.from_pydict( + { + 'id': pa.array([1, 2], type=pa.int32()), + 'name': ['a2', 'b2'], + 'age': pa.array([5, 5], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[ + WhenMatched(update='*', condition='s.age > 50'), + WhenMatched(update='*', condition='s.age > 30'), + ], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + out = self._read_sorted(target) + 
self.assertEqual(out['name'], ['a', 'b']) + self.assertEqual(out['age'], [10, 20]) + + @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON) + def test_multi_clause_first_wins(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1], type=pa.int32()), + 'name': ['old'], + 'age': pa.array([10], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source = pa.Table.from_pydict( + { + 'id': pa.array([1], type=pa.int32()), + 'name': ['first'], + 'age': pa.array([99], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[ + WhenMatched(update={'name': 's.name'}, + condition='s.age > 50'), + WhenMatched(update={'age': 's.age'}, + condition='s.age > 10'), + ], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + out = self._read_sorted(target) + self.assertEqual(out['name'], ['first']) + self.assertEqual(out['age'], [10]) + + @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON) + def test_multi_clause_duplicate_source_one_actionable(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1], type=pa.int32()), + 'name': ['a'], + 'age': pa.array([10], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source = pa.Table.from_pydict( + { + 'id': pa.array([1, 1], type=pa.int32()), + 'name': ['x', 'y'], + 'age': pa.array([99, 5], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[ + WhenMatched(update='*', condition='s.age > 50'), + WhenMatched(update='*', condition='s.age > 80'), + ], + num_partitions=_TEST_NUM_PARTITIONS, + ) + + out = self._read_sorted(target) + self.assertEqual(out['name'], ['x']) + self.assertEqual(out['age'], [99]) + + @unittest.skipIf(_SKIP_CONDITION, _SKIP_REASON) + def 
test_multi_clause_duplicate_both_actionable_raises(self): + target = self._create_table() + self._write( + target, + pa.Table.from_pydict( + { + 'id': pa.array([1], type=pa.int32()), + 'name': ['a'], + 'age': pa.array([10], type=pa.int32()), + }, + schema=self.pa_schema, + ), + ) + + source = pa.Table.from_pydict( + { + 'id': pa.array([1, 1], type=pa.int32()), + 'name': ['x', 'y'], + 'age': pa.array([99, 50], type=pa.int32()), + }, + schema=self.pa_schema, + ) + + with self.assertRaises(Exception) as ctx: + merge_into( + target=target, + source=source, + catalog_options=self.catalog_options, + on=['id'], + when_matched=[ + WhenMatched(update='*', condition='s.age > 80'), + WhenMatched(update='*', condition='s.age > 30'), + ], + num_partitions=_TEST_NUM_PARTITIONS, + ) + self.assertIn('multiple source rows', str(ctx.exception)) + class TargetProjectionTest(unittest.TestCase):