Skip to content

Commit dc8d119

Browse files
authored
Revert "Remove CoalescePartitions insertion from HashJoinExec (#15476)" (#15496)
This reverts commit 7e0738a.
1 parent 102f879 commit dc8d119

File tree

1 file changed

+36
-34
lines changed

1 file changed

+36
-34
lines changed

datafusion/physical-plan/src/joins/hash_join.rs

Lines changed: 36 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ use crate::projection::{
4040
use crate::spill::get_record_batch_memory_size;
4141
use crate::ExecutionPlanProperties;
4242
use crate::{
43+
coalesce_partitions::CoalescePartitionsExec,
4344
common::can_project,
4445
handle_state,
4546
hash_utils::create_hashes,
@@ -791,44 +792,34 @@ impl ExecutionPlan for HashJoinExec {
791792
);
792793
}
793794

794-
if self.mode == PartitionMode::CollectLeft && left_partitions != 1 {
795-
return internal_err!(
796-
"Invalid HashJoinExec,the output partition count of the left child must be 1 in CollectLeft mode,\
797-
consider using CoalescePartitionsExec"
798-
);
799-
}
800-
801795
let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics);
802796
let left_fut = match self.mode {
803-
PartitionMode::CollectLeft => {
804-
let left_stream = self.left.execute(0, Arc::clone(&context))?;
805-
806-
self.left_fut.once(|| {
807-
let reservation = MemoryConsumer::new("HashJoinInput")
808-
.register(context.memory_pool());
809-
810-
collect_left_input(
811-
self.random_state.clone(),
812-
left_stream,
813-
on_left.clone(),
814-
join_metrics.clone(),
815-
reservation,
816-
need_produce_result_in_final(self.join_type),
817-
self.right().output_partitioning().partition_count(),
818-
)
819-
})
820-
}
797+
PartitionMode::CollectLeft => self.left_fut.once(|| {
798+
let reservation =
799+
MemoryConsumer::new("HashJoinInput").register(context.memory_pool());
800+
collect_left_input(
801+
None,
802+
self.random_state.clone(),
803+
Arc::clone(&self.left),
804+
on_left.clone(),
805+
Arc::clone(&context),
806+
join_metrics.clone(),
807+
reservation,
808+
need_produce_result_in_final(self.join_type),
809+
self.right().output_partitioning().partition_count(),
810+
)
811+
}),
821812
PartitionMode::Partitioned => {
822-
let left_stream = self.left.execute(partition, Arc::clone(&context))?;
823-
824813
let reservation =
825814
MemoryConsumer::new(format!("HashJoinInput[{partition}]"))
826815
.register(context.memory_pool());
827816

828817
OnceFut::new(collect_left_input(
818+
Some(partition),
829819
self.random_state.clone(),
830-
left_stream,
820+
Arc::clone(&self.left),
831821
on_left.clone(),
822+
Arc::clone(&context),
832823
join_metrics.clone(),
833824
reservation,
834825
need_produce_result_in_final(self.join_type),
@@ -939,22 +930,36 @@ impl ExecutionPlan for HashJoinExec {
939930

940931
/// Reads the left (build) side of the input, buffering it in memory, to build a
941932
/// hash table (`LeftJoinData`)
933+
#[allow(clippy::too_many_arguments)]
942934
async fn collect_left_input(
935+
partition: Option<usize>,
943936
random_state: RandomState,
944-
left_stream: SendableRecordBatchStream,
937+
left: Arc<dyn ExecutionPlan>,
945938
on_left: Vec<PhysicalExprRef>,
939+
context: Arc<TaskContext>,
946940
metrics: BuildProbeJoinMetrics,
947941
reservation: MemoryReservation,
948942
with_visited_indices_bitmap: bool,
949943
probe_threads_count: usize,
950944
) -> Result<JoinLeftData> {
951-
let schema = left_stream.schema();
945+
let schema = left.schema();
946+
947+
let (left_input, left_input_partition) = if let Some(partition) = partition {
948+
(left, partition)
949+
} else if left.output_partitioning().partition_count() != 1 {
950+
(Arc::new(CoalescePartitionsExec::new(left)) as _, 0)
951+
} else {
952+
(left, 0)
953+
};
954+
955+
// Depending on partition argument load single partition or whole left side in memory
956+
let stream = left_input.execute(left_input_partition, Arc::clone(&context))?;
952957

953958
// This operation performs 2 steps at once:
954959
// 1. creates a [JoinHashMap] of all batches from the stream
955960
// 2. stores the batches in a vector.
956961
let initial = (Vec::new(), 0, metrics, reservation);
957-
let (batches, num_rows, metrics, mut reservation) = left_stream
962+
let (batches, num_rows, metrics, mut reservation) = stream
958963
.try_fold(initial, |mut acc, batch| async {
959964
let batch_size = get_record_batch_memory_size(&batch);
960965
// Reserve memory for incoming batch
@@ -1650,7 +1655,6 @@ impl EmbeddedProjection for HashJoinExec {
16501655
#[cfg(test)]
16511656
mod tests {
16521657
use super::*;
1653-
use crate::coalesce_partitions::CoalescePartitionsExec;
16541658
use crate::test::TestMemoryExec;
16551659
use crate::{
16561660
common, expressions::Column, repartition::RepartitionExec, test::build_table_i32,
@@ -2101,7 +2105,6 @@ mod tests {
21012105
let left =
21022106
TestMemoryExec::try_new_exec(&[vec![batch1], vec![batch2]], schema, None)
21032107
.unwrap();
2104-
let left = Arc::new(CoalescePartitionsExec::new(left));
21052108

21062109
let right = build_table(
21072110
("a1", &vec![1, 2, 3]),
@@ -2174,7 +2177,6 @@ mod tests {
21742177
let left =
21752178
TestMemoryExec::try_new_exec(&[vec![batch1], vec![batch2]], schema, None)
21762179
.unwrap();
2177-
let left = Arc::new(CoalescePartitionsExec::new(left));
21782180
let right = build_table(
21792181
("a2", &vec![20, 30, 10]),
21802182
("b2", &vec![5, 6, 4]),

0 commit comments

Comments
 (0)