@@ -40,6 +40,7 @@ use crate::projection::{
40
40
use crate :: spill:: get_record_batch_memory_size;
41
41
use crate :: ExecutionPlanProperties ;
42
42
use crate :: {
43
+ coalesce_partitions:: CoalescePartitionsExec ,
43
44
common:: can_project,
44
45
handle_state,
45
46
hash_utils:: create_hashes,
@@ -791,44 +792,34 @@ impl ExecutionPlan for HashJoinExec {
791
792
) ;
792
793
}
793
794
794
- if self . mode == PartitionMode :: CollectLeft && left_partitions != 1 {
795
- return internal_err ! (
796
- "Invalid HashJoinExec,the output partition count of the left child must be 1 in CollectLeft mode,\
797
- consider using CoalescePartitionsExec"
798
- ) ;
799
- }
800
-
801
795
let join_metrics = BuildProbeJoinMetrics :: new ( partition, & self . metrics ) ;
802
796
let left_fut = match self . mode {
803
- PartitionMode :: CollectLeft => {
804
- let left_stream = self . left . execute ( 0 , Arc :: clone ( & context) ) ?;
805
-
806
- self . left_fut . once ( || {
807
- let reservation = MemoryConsumer :: new ( "HashJoinInput" )
808
- . register ( context. memory_pool ( ) ) ;
809
-
810
- collect_left_input (
811
- self . random_state . clone ( ) ,
812
- left_stream,
813
- on_left. clone ( ) ,
814
- join_metrics. clone ( ) ,
815
- reservation,
816
- need_produce_result_in_final ( self . join_type ) ,
817
- self . right ( ) . output_partitioning ( ) . partition_count ( ) ,
818
- )
819
- } )
820
- }
797
+ PartitionMode :: CollectLeft => self . left_fut . once ( || {
798
+ let reservation =
799
+ MemoryConsumer :: new ( "HashJoinInput" ) . register ( context. memory_pool ( ) ) ;
800
+ collect_left_input (
801
+ None ,
802
+ self . random_state . clone ( ) ,
803
+ Arc :: clone ( & self . left ) ,
804
+ on_left. clone ( ) ,
805
+ Arc :: clone ( & context) ,
806
+ join_metrics. clone ( ) ,
807
+ reservation,
808
+ need_produce_result_in_final ( self . join_type ) ,
809
+ self . right ( ) . output_partitioning ( ) . partition_count ( ) ,
810
+ )
811
+ } ) ,
821
812
PartitionMode :: Partitioned => {
822
- let left_stream = self . left . execute ( partition, Arc :: clone ( & context) ) ?;
823
-
824
813
let reservation =
825
814
MemoryConsumer :: new ( format ! ( "HashJoinInput[{partition}]" ) )
826
815
. register ( context. memory_pool ( ) ) ;
827
816
828
817
OnceFut :: new ( collect_left_input (
818
+ Some ( partition) ,
829
819
self . random_state . clone ( ) ,
830
- left_stream ,
820
+ Arc :: clone ( & self . left ) ,
831
821
on_left. clone ( ) ,
822
+ Arc :: clone ( & context) ,
832
823
join_metrics. clone ( ) ,
833
824
reservation,
834
825
need_produce_result_in_final ( self . join_type ) ,
@@ -939,22 +930,36 @@ impl ExecutionPlan for HashJoinExec {
939
930
940
931
/// Reads the left (build) side of the input, buffering it in memory, to build a
941
932
/// hash table (`JoinLeftData`)
933
+ #[ allow( clippy:: too_many_arguments) ]
942
934
async fn collect_left_input (
935
+ partition : Option < usize > ,
943
936
random_state : RandomState ,
944
- left_stream : SendableRecordBatchStream ,
937
+ left : Arc < dyn ExecutionPlan > ,
945
938
on_left : Vec < PhysicalExprRef > ,
939
+ context : Arc < TaskContext > ,
946
940
metrics : BuildProbeJoinMetrics ,
947
941
reservation : MemoryReservation ,
948
942
with_visited_indices_bitmap : bool ,
949
943
probe_threads_count : usize ,
950
944
) -> Result < JoinLeftData > {
951
- let schema = left_stream. schema ( ) ;
945
+ let schema = left. schema ( ) ;
946
+
947
+ let ( left_input, left_input_partition) = if let Some ( partition) = partition {
948
+ ( left, partition)
949
+ } else if left. output_partitioning ( ) . partition_count ( ) != 1 {
950
+ ( Arc :: new ( CoalescePartitionsExec :: new ( left) ) as _ , 0 )
951
+ } else {
952
+ ( left, 0 )
953
+ } ;
954
+
955
+ // Depending on the `partition` argument, load either a single partition or the whole left side into memory
956
+ let stream = left_input. execute ( left_input_partition, Arc :: clone ( & context) ) ?;
952
957
953
958
// This operation performs 2 steps at once:
954
959
// 1. creates a [JoinHashMap] of all batches from the stream
955
960
// 2. stores the batches in a vector.
956
961
let initial = ( Vec :: new ( ) , 0 , metrics, reservation) ;
957
- let ( batches, num_rows, metrics, mut reservation) = left_stream
962
+ let ( batches, num_rows, metrics, mut reservation) = stream
958
963
. try_fold ( initial, |mut acc, batch| async {
959
964
let batch_size = get_record_batch_memory_size ( & batch) ;
960
965
// Reserve memory for incoming batch
@@ -1650,7 +1655,6 @@ impl EmbeddedProjection for HashJoinExec {
1650
1655
#[ cfg( test) ]
1651
1656
mod tests {
1652
1657
use super :: * ;
1653
- use crate :: coalesce_partitions:: CoalescePartitionsExec ;
1654
1658
use crate :: test:: TestMemoryExec ;
1655
1659
use crate :: {
1656
1660
common, expressions:: Column , repartition:: RepartitionExec , test:: build_table_i32,
@@ -2101,7 +2105,6 @@ mod tests {
2101
2105
let left =
2102
2106
TestMemoryExec :: try_new_exec ( & [ vec ! [ batch1] , vec ! [ batch2] ] , schema, None )
2103
2107
. unwrap ( ) ;
2104
- let left = Arc :: new ( CoalescePartitionsExec :: new ( left) ) ;
2105
2108
2106
2109
let right = build_table (
2107
2110
( "a1" , & vec ! [ 1 , 2 , 3 ] ) ,
@@ -2174,7 +2177,6 @@ mod tests {
2174
2177
let left =
2175
2178
TestMemoryExec :: try_new_exec ( & [ vec ! [ batch1] , vec ! [ batch2] ] , schema, None )
2176
2179
. unwrap ( ) ;
2177
- let left = Arc :: new ( CoalescePartitionsExec :: new ( left) ) ;
2178
2180
let right = build_table (
2179
2181
( "a2" , & vec ! [ 20 , 30 , 10 ] ) ,
2180
2182
( "b2" , & vec ! [ 5 , 6 , 4 ] ) ,
0 commit comments