Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion datafusion/core/tests/fuzz_cases/join_fuzz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1276,7 +1276,14 @@ fn make_staggered_batches_i32(len: usize, with_extra_column: bool) -> Vec<Record
input12.sort_unstable();
let input1 = Int32Array::from_iter_values(input12.clone().into_iter().map(|k| k.0));
let input2 = Int32Array::from_iter_values(input12.clone().into_iter().map(|k| k.1));
let input3 = Int32Array::from_iter_values(input3);
let input3 = Int32Array::from_iter(input3.into_iter().map(|v| {
// ~10% NULLs in filter column to exercise NULL filter handling
if rng.random_range(0..10) == 0 {
None
} else {
Some(v)
}
}));
let input4 = Int32Array::from_iter_values(input4);

let mut columns = vec![
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1423,27 +1423,22 @@ impl MaterializingSortMergeJoinStream {
.evaluate(&filter_batch)?
.into_array(filter_batch.num_rows())?;

let pre_mask = datafusion_common::cast::as_boolean_array(&filter_result)?;
let filter_result_mask =
datafusion_common::cast::as_boolean_array(&filter_result)?;

let mask = if pre_mask.null_count() > 0 {
compute::prep_null_mask_filter(pre_mask)
// Convert NULL filter results to false — NULL means "not satisfied"
// per SQL semantics, same as Left/Right outer joins.
let mask = if filter_result_mask.null_count() > 0 {
compute::prep_null_mask_filter(filter_result_mask)
} else {
pre_mask.clone()
filter_result_mask.clone()
};

if needs_deferred_filtering(&self.filter, self.join_type) {
// Full join uses pre_mask (preserving nulls) for
// get_corrected_filter_mask; other outer joins use mask.
let mask_to_use = if self.join_type != JoinType::Full {
&mask
} else {
pre_mask
};

self.joined_record_batches.push_batch_with_filter_metadata(
output_batch,
&combined_left_indices,
mask_to_use,
&mask,
self.streamed_batch_counter.load(Relaxed),
self.join_type,
);
Expand All @@ -1468,7 +1463,7 @@ impl MaterializingSortMergeJoinStream {
let idx = right.value(i) as usize;
match buffered_batch.join_filter_status[idx] {
FilterState::SomePassed => {}
_ if pre_mask.value(offset + i) => {
_ if mask.value(offset + i) => {
buffered_batch.join_filter_status[idx] =
FilterState::SomePassed;
}
Expand Down
109 changes: 105 additions & 4 deletions datafusion/physical-plan/src/joins/sort_merge_join/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ use datafusion_common::{
test_util::{batches_to_sort_string, batches_to_string},
};
use datafusion_common::{
JoinType, NullEquality, Result, assert_batches_eq, assert_contains,
JoinType, NullEquality, Result, ScalarValue, assert_batches_eq, assert_contains,
};
use datafusion_common_runtime::JoinSet;
use datafusion_execution::config::SessionConfig;
Expand All @@ -65,6 +65,7 @@ use datafusion_execution::runtime_env::RuntimeEnvBuilder;
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::BinaryExpr;
use datafusion_physical_expr::expressions::Literal;
use datafusion_physical_expr_common::physical_expr::PhysicalExprRef;
use futures::{Stream, StreamExt};
use insta::assert_snapshot;
Expand Down Expand Up @@ -2049,6 +2050,108 @@ async fn join_full_multiple_batches() -> Result<()> {
Ok(())
}

/// Full outer join where the filter evaluates to NULL due to a nullable column.
///
/// NULL filter results must be treated as unmatched, not matched.
/// Reproducer for SPARK-43113.
///
/// Setup:
/// - Left side `(a1, b1)`: six non-null rows, two per join key 1, 2 and 3.
/// - Right side `(a2, b2)`: two rows, `(1, NULL)` and `(2, 2)`.
/// - Join key: `a1 = a2`.
/// - Join filter: `b1 < (b2 + 1) AND b1 < (a2 + 1)`.
///
/// Expected: every row with key 1 is emitted unmatched on both sides (the
/// filter is NULL there), both key-2 rows are joined, and the key-3 left rows
/// are emitted unmatched because the right side has no such key.
#[tokio::test]
async fn join_full_null_filter_result() -> Result<()> {
    // Left: (a, b) all non-null, sorted on a
    let left = build_table_two_cols(
        ("a1", &vec![1, 1, 2, 2, 3, 3]),
        ("b1", &vec![1, 2, 1, 2, 1, 2]),
    );

    // Right: (a, b) with b nullable, sorted on a
    let right_schema = Arc::new(Schema::new(vec![
        Field::new("a2", DataType::Int32, false),
        Field::new("b2", DataType::Int32, true),
    ]));
    let right_batch = RecordBatch::try_new(
        Arc::clone(&right_schema),
        vec![
            Arc::new(Int32Array::from(vec![1, 2])),
            Arc::new(Int32Array::from(vec![None, Some(2)])),
        ],
    )?;
    // Propagate construction errors instead of panicking; this test already
    // returns `Result`.
    let right = TestMemoryExec::try_new_exec(&[vec![right_batch]], right_schema, None)?;

    let on = vec![(
        Arc::new(Column::new_with_schema("a1", &left.schema())?) as _,
        Arc::new(Column::new_with_schema("a2", &right.schema())?) as _,
    )];

    // Filter: b1 < (b2 + 1) AND b1 < (a2 + 1)
    // When b2 is NULL, (b2 + 1) is NULL, so b1 < NULL is NULL → unmatched.
    //
    // The column indices used below (b1 = 0, b2 = 1, a2 = 2) are positions in
    // the filter's intermediate schema passed to `JoinFilter::new`, not in the
    // input tables.
    let lit_1: PhysicalExprRef = Arc::new(Literal::new(ScalarValue::Int32(Some(1))));
    let b1_lt_b2_plus_1: PhysicalExprRef = Arc::new(BinaryExpr::new(
        Arc::new(Column::new("b1", 0)),
        Operator::Lt,
        Arc::new(BinaryExpr::new(
            Arc::new(Column::new("b2", 1)),
            Operator::Plus,
            Arc::clone(&lit_1),
        )),
    ));
    let b1_lt_a2_plus_1: PhysicalExprRef = Arc::new(BinaryExpr::new(
        Arc::new(Column::new("b1", 0)),
        Operator::Lt,
        Arc::new(BinaryExpr::new(
            Arc::new(Column::new("a2", 2)),
            Operator::Plus,
            Arc::clone(&lit_1),
        )),
    ));
    let filter_expr: PhysicalExprRef = Arc::new(BinaryExpr::new(
        b1_lt_b2_plus_1,
        Operator::And,
        b1_lt_a2_plus_1,
    ));

    let filter = JoinFilter::new(
        filter_expr,
        // One entry per intermediate-schema column, in order: where each
        // filter input is read from in the join inputs.
        vec![
            // b1 ← left column 1
            ColumnIndex {
                index: 1,
                side: JoinSide::Left,
            },
            // b2 ← right column 1
            ColumnIndex {
                index: 1,
                side: JoinSide::Right,
            },
            // a2 ← right column 0
            ColumnIndex {
                index: 0,
                side: JoinSide::Right,
            },
        ],
        Arc::new(Schema::new(vec![
            Field::new("b1", DataType::Int32, true),
            Field::new("b2", DataType::Int32, true),
            Field::new("a2", DataType::Int32, true),
        ])),
    );

    let (_, batches) = join_collect_with_filter(left, right, on, filter, Full).await?;

    // r=(1,NULL): b2 is NULL → b1 < (NULL+1) is NULL → all a=1 rows unmatched
    // r=(2,2): b1 < 3 AND b1 < 3 → both l=(2,1) and l=(2,2) match
    // l=(3,*): no right row with a=3 → unmatched
    assert_snapshot!(batches_to_sort_string(&batches), @r"
    +----+----+----+----+
    | a1 | b1 | a2 | b2 |
    +----+----+----+----+
    |    |    | 1  |    |
    | 1  | 1  |    |    |
    | 1  | 2  |    |    |
    | 2  | 1  | 2  | 2  |
    | 2  | 2  | 2  | 2  |
    | 3  | 1  |    |    |
    | 3  | 2  |    |    |
    +----+----+----+----+
    ");
    Ok(())
}

#[tokio::test]
async fn overallocation_single_batch_no_spill() -> Result<()> {
let left = build_table(
Expand Down Expand Up @@ -3589,9 +3692,7 @@ async fn join_filtered_with_multiple_buffered_batches() -> Result<()> {
Arc::new(Column::new("val_r", 1)),
)),
Operator::Lt,
Arc::new(datafusion_physical_expr::expressions::Literal::new(
datafusion_common::ScalarValue::Int32(Some(350)),
)),
Arc::new(Literal::new(ScalarValue::Int32(Some(350)))),
)),
vec![
ColumnIndex {
Expand Down
Loading