Commit 2c19116

maintaining fifo hashmap in hash join

1 parent 6f5230f · commit 2c19116

3 files changed (+141, -78 lines)

datafusion/physical-plan/src/joins/hash_join.rs (61 additions, 77 deletions)

@@ -29,7 +29,6 @@ use crate::joins::utils::{
     need_produce_result_in_final, JoinHashMap, JoinHashMapType,
 };
 use crate::{
-    coalesce_batches::concat_batches,
     coalesce_partitions::CoalescePartitionsExec,
     expressions::Column,
     expressions::PhysicalSortExpr,
@@ -52,10 +51,10 @@ use super::{
 
 use arrow::array::{
     Array, ArrayRef, BooleanArray, BooleanBufferBuilder, PrimitiveArray, UInt32Array,
-    UInt32BufferBuilder, UInt64Array, UInt64BufferBuilder,
+    UInt64Array,
 };
 use arrow::compute::kernels::cmp::{eq, not_distinct};
-use arrow::compute::{and, take, FilterBuilder};
+use arrow::compute::{and, concat_batches, take, FilterBuilder};
 use arrow::datatypes::{Schema, SchemaRef};
 use arrow::record_batch::RecordBatch;
 use arrow::util::bit_util;
@@ -715,7 +714,10 @@ async fn collect_left_input(
     let mut hashmap = JoinHashMap::with_capacity(num_rows);
     let mut hashes_buffer = Vec::new();
     let mut offset = 0;
-    for batch in batches.iter() {
+
+    // Reverse iteration over build-side input batches allows to create FIFO hashmap
+    let batches_iter = batches.iter().rev();
+    for batch in batches_iter.clone() {
         hashes_buffer.clear();
         hashes_buffer.resize(batch.num_rows(), 0);
         update_hash(
@@ -726,19 +728,25 @@ async fn collect_left_input(
             &random_state,
             &mut hashes_buffer,
             0,
+            true,
         )?;
         offset += batch.num_rows();
     }
     // Merge all batches into a single batch, so we
     // can directly index into the arrays
-    let single_batch = concat_batches(&schema, &batches, num_rows)?;
+    let single_batch = concat_batches(&schema, batches_iter)?;
     let data = JoinLeftData::new(hashmap, single_batch, reservation);
 
     Ok(data)
 }
 
-/// Updates `hash` with new entries from [RecordBatch] evaluated against the expressions `on`,
-/// assuming that the [RecordBatch] corresponds to the `index`th
+/// Updates `hash_map` with new entries from `batch` evaluated against the expressions `on`
+/// using `offset` as a start value for `batch` row indices.
+///
+/// `fifo_hashmap` sets the order of iteration over `batch` rows while updating hashmap,
+/// which allows to keep either first (if set to true) or last (if set to false) row index
+/// as a chain head for matching hashes.
+#[allow(clippy::too_many_arguments)]
 pub fn update_hash<T>(
     on: &[Column],
     batch: &RecordBatch,
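
The hunks above are the core of the change: the join hash map is a chained-list structure in which the map stores, for each hash value, the chain-head row index plus one (zero marks "no entry"), and a next-pointer vector links the remaining rows with the same hash. Every insertion prepends to the chain head, so inserting build rows in reverse order leaves the earliest row at the head, and probe-time traversal then visits build rows in their original (FIFO) order. The sketch below models this with a plain std::collections::HashMap instead of DataFusion's RawTable-backed JoinHashMap; all names in it are illustrative only.

use std::collections::HashMap;

/// Toy model of a chained-list join hash map:
/// `head` maps a key hash to `row_index + 1` of the chain head (0 = no entry),
/// `next` links each row to the head it replaced (0 = end of list).
struct ToyJoinMap {
    head: HashMap<u64, u64>,
    next: Vec<u64>,
}

impl ToyJoinMap {
    fn with_capacity(num_rows: usize) -> Self {
        Self { head: HashMap::new(), next: vec![0; num_rows] }
    }

    /// Insert rows in the order produced by `iter`; each inserted row becomes the new chain head.
    fn update_from_iter<'a>(&mut self, iter: impl Iterator<Item = (usize, &'a u64)>) {
        for (row, hash) in iter {
            let prev = self.head.insert(*hash, (row + 1) as u64).unwrap_or(0);
            // Link the new head to whatever was the head before (0 terminates the chain).
            self.next[row] = prev;
        }
    }

    /// Walk the chain for `hash`, yielding build-row indices starting from the head.
    fn matches(&self, hash: u64) -> Vec<usize> {
        let mut out = Vec::new();
        let mut i = *self.head.get(&hash).unwrap_or(&0);
        while i != 0 {
            out.push((i - 1) as usize);
            i = self.next[(i - 1) as usize];
        }
        out
    }
}

fn main() {
    // Three build rows (0, 1, 2) share the same join-key hash.
    let hashes = vec![42_u64, 42, 42];

    // Forward insertion: the last inserted row becomes the chain head (LIFO order).
    let mut lifo = ToyJoinMap::with_capacity(hashes.len());
    lifo.update_from_iter(hashes.iter().enumerate());
    assert_eq!(lifo.matches(42), vec![2, 1, 0]);

    // Reverse insertion (what the FIFO hashmap does): row 0 stays at the head.
    let mut fifo = ToyJoinMap::with_capacity(hashes.len());
    fifo.update_from_iter(hashes.iter().enumerate().rev());
    assert_eq!(fifo.matches(42), vec![0, 1, 2]);
}
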
@@ -747,6 +755,7 @@ pub fn update_hash<T>(
     random_state: &RandomState,
     hashes_buffer: &mut Vec<u64>,
     deleted_offset: usize,
+    fifo_hashmap: bool,
 ) -> Result<()>
 where
     T: JoinHashMapType,
@@ -763,28 +772,18 @@ where
     // For usual JoinHashmap, the implementation is void.
     hash_map.extend_zero(batch.num_rows());
 
-    // insert hashes to key of the hashmap
-    let (mut_map, mut_list) = hash_map.get_mut();
-    for (row, hash_value) in hash_values.iter().enumerate() {
-        let item = mut_map.get_mut(*hash_value, |(hash, _)| *hash_value == *hash);
-        if let Some((_, index)) = item {
-            // Already exists: add index to next array
-            let prev_index = *index;
-            // Store new value inside hashmap
-            *index = (row + offset + 1) as u64;
-            // Update chained Vec at row + offset with previous value
-            mut_list[row + offset - deleted_offset] = prev_index;
-        } else {
-            mut_map.insert(
-                *hash_value,
-                // store the value + 1 as 0 value reserved for end of list
-                (*hash_value, (row + offset + 1) as u64),
-                |(hash, _)| *hash,
-            );
-            // chained list at (row + offset) is already initialized with 0
-            // meaning end of list
-        }
+    // Updating JoinHashMap from hash values iterator
+    let hash_values_iter = hash_values
+        .iter()
+        .enumerate()
+        .map(|(i, val)| (i + offset, val));
+
+    if fifo_hashmap {
+        hash_map.update_from_iter(hash_values_iter.rev(), deleted_offset);
+    } else {
+        hash_map.update_from_iter(hash_values_iter, deleted_offset);
     }
+
     Ok(())
 }
 
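update_hash still receives a running offset so that the hashmap stores row indices that are global across all build batches; the added .map(|(i, val)| (i + offset, val)) applies that shift before the pairs reach update_from_iter. Because the batches are now visited in reverse, the same reversed iterator is later passed to concat_batches, so the accumulated offsets line up with row positions in the concatenated build batch. A minimal sketch of the offset arithmetic follows; the helper name is hypothetical, not part of DataFusion.

/// Start offsets of each batch, accumulated in visit order. When the build side is
/// visited in reverse to obtain a FIFO hashmap, the concatenated batch must be
/// assembled from the same reversed iterator for these offsets to stay valid.
fn offsets_in_visit_order(batch_sizes_in_visit_order: &[usize]) -> Vec<usize> {
    let mut offset = 0;
    let mut offsets = Vec::with_capacity(batch_sizes_in_visit_order.len());
    for &num_rows in batch_sizes_in_visit_order {
        offsets.push(offset); // start index of this batch in the concatenated batch
        offset += num_rows;
    }
    offsets
}

fn main() {
    // Build batches arrive as [A (3 rows), B (2 rows)] but are visited reversed as [B, A]:
    // B occupies global rows 0..2 and A occupies rows 2..5, matching concat_batches([B, A]).
    assert_eq!(offsets_in_visit_order(&[2, 3]), vec![0, 2]);
}
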
@@ -987,6 +986,7 @@ pub fn build_equal_condition_join_indices<T: JoinHashMapType>(
     filter: Option<&JoinFilter>,
     build_side: JoinSide,
     deleted_offset: Option<usize>,
+    fifo_hashmap: bool,
 ) -> Result<(UInt64Array, UInt32Array)> {
     let keys_values = probe_on
         .iter()
@@ -1002,10 +1002,9 @@ pub fn build_equal_condition_join_indices<T: JoinHashMapType>(
     hashes_buffer.clear();
     hashes_buffer.resize(probe_batch.num_rows(), 0);
     let hash_values = create_hashes(&keys_values, random_state, hashes_buffer)?;
-    // Using a buffer builder to avoid slower normal builder
-    let mut build_indices = UInt64BufferBuilder::new(0);
-    let mut probe_indices = UInt32BufferBuilder::new(0);
-    // The chained list algorithm generates build indices for each probe row in a reversed sequence as such:
+
+    // In case build-side input has not been inverted while JoinHashMap creation, the chained list algorithm
+    // will return build indices for each probe row in a reverse order:
     // Build Indices: [5, 4, 3]
     // Probe Indices: [1, 1, 1]
     //
@@ -1034,44 +1033,17 @@ pub fn build_equal_condition_join_indices<T: JoinHashMapType>(
     // (5,1)
     //
     // With this approach, the lexicographic order on both the probe side and the build side is preserved.
-    let hash_map = build_hashmap.get_map();
-    let next_chain = build_hashmap.get_list();
-    for (row, hash_value) in hash_values.iter().enumerate().rev() {
-        // Get the hash and find it in the build index
-
-        // For every item on the build and probe we check if it matches
-        // This possibly contains rows with hash collisions,
-        // So we have to check here whether rows are equal or not
-        if let Some((_, index)) =
-            hash_map.get(*hash_value, |(hash, _)| *hash_value == *hash)
-        {
-            let mut i = *index - 1;
-            loop {
-                let build_row_value = if let Some(offset) = deleted_offset {
-                    // This arguments means that we prune the next index way before here.
-                    if i < offset as u64 {
-                        // End of the list due to pruning
-                        break;
-                    }
-                    i - offset as u64
-                } else {
-                    i
-                };
-                build_indices.append(build_row_value);
-                probe_indices.append(row as u32);
-                // Follow the chain to get the next index value
-                let next = next_chain[build_row_value as usize];
-                if next == 0 {
-                    // end of list
-                    break;
-                }
-                i = next - 1;
-            }
-        }
-    }
-    // Reversing both sets of indices
-    build_indices.as_slice_mut().reverse();
-    probe_indices.as_slice_mut().reverse();
+    let (mut build_indices, mut probe_indices) = if fifo_hashmap {
+        build_hashmap.get_matched_indices(hash_values.iter().enumerate(), deleted_offset)
+    } else {
+        let (mut matched_build, mut matched_probe) = build_hashmap
+            .get_matched_indices(hash_values.iter().enumerate().rev(), deleted_offset);
+
+        matched_build.as_slice_mut().reverse();
+        matched_probe.as_slice_mut().reverse();
+
+        (matched_build, matched_probe)
+    };
 
     let left: UInt64Array = PrimitiveArray::new(build_indices.finish().into(), None);
     let right: UInt32Array = PrimitiveArray::new(probe_indices.finish().into(), None);
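
With a FIFO hashmap, walking the probe hashes in natural order already produces probe indices in ascending order with each probe row's build matches oldest-first; the legacy path keeps its old behavior by iterating probe rows in reverse and then reversing both index buffers. The equivalence of the two strategies can be seen in the following sketch, which uses plain Vecs in place of the arrow buffer builders (all names are illustrative).

/// `chains` maps each probe row to its matching build rows as stored in the hashmap:
/// oldest-first for a FIFO hashmap, newest-first for the legacy one.
fn collect_indices(chains: &[Vec<u64>], reverse: bool) -> (Vec<u64>, Vec<u32>) {
    let mut build = Vec::new();
    let mut probe = Vec::new();
    let rows: Vec<usize> = if reverse {
        (0..chains.len()).rev().collect()
    } else {
        (0..chains.len()).collect()
    };
    for row in rows {
        for &b in &chains[row] {
            build.push(b);
            probe.push(row as u32);
        }
    }
    if reverse {
        // Legacy path: compensate for the reversed probe iteration.
        build.reverse();
        probe.reverse();
    }
    (build, probe)
}

fn main() {
    // FIFO hashmap: chains run oldest-to-newest, forward probe iteration suffices.
    let fifo_chains = vec![vec![3, 4, 5], vec![7]];
    // Legacy hashmap: chains run newest-to-oldest (the [5, 4, 3] case from the comment above),
    // so probe rows are walked in reverse and both buffers are reversed afterwards.
    let lifo_chains = vec![vec![5, 4, 3], vec![7]];

    assert_eq!(
        collect_indices(&fifo_chains, false),
        collect_indices(&lifo_chains, true)
    );
    assert_eq!(
        collect_indices(&fifo_chains, false),
        (vec![3, 4, 5, 7], vec![0, 0, 0, 1])
    );
}
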
@@ -1279,6 +1251,7 @@ impl HashJoinStream {
                 self.filter.as_ref(),
                 JoinSide::Left,
                 None,
+                true,
             );
 
             let result = match left_right_indices {
@@ -1393,7 +1366,9 @@ mod tests {
 
     use arrow::array::{ArrayRef, Date32Array, Int32Array, UInt32Builder, UInt64Builder};
     use arrow::datatypes::{DataType, Field, Schema};
-    use datafusion_common::{assert_batches_sorted_eq, assert_contains, ScalarValue};
+    use datafusion_common::{
+        assert_batches_eq, assert_batches_sorted_eq, assert_contains, ScalarValue,
+    };
     use datafusion_execution::config::SessionConfig;
     use datafusion_execution::runtime_env::{RuntimeConfig, RuntimeEnv};
     use datafusion_expr::Operator;
@@ -1558,7 +1533,9 @@ mod tests {
            "| 3 | 5 | 9 | 20 | 5 | 80 |",
            "+----+----+----+----+----+----+",
        ];
-        assert_batches_sorted_eq!(expected, &batches);
+
+        // Inner join output is expected to preserve both inputs order
+        assert_batches_eq!(expected, &batches);
 
        Ok(())
    }
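
This and the following test hunks switch from assert_batches_sorted_eq!, which sorts the pretty-printed rows before comparing and therefore only checks the set of rows, to assert_batches_eq!, which compares the formatted output line by line and therefore also pins the row order. A minimal illustration of the difference on a hand-built two-row batch (the expected layout is written the way arrow's pretty printer typically renders it):

use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
use datafusion_common::{assert_batches_eq, assert_batches_sorted_eq};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("a1", DataType::Int32, false),
        Field::new("b1", DataType::Int32, false),
    ]));
    let batch = RecordBatch::try_new(
        schema,
        vec![
            Arc::new(Int32Array::from(vec![1, 2])) as ArrayRef,
            Arc::new(Int32Array::from(vec![4, 5])) as ArrayRef,
        ],
    )?;
    let batches = vec![batch];

    let expected = [
        "+----+----+",
        "| a1 | b1 |",
        "+----+----+",
        "| 1  | 4  |",
        "| 2  | 5  |",
        "+----+----+",
    ];
    // Order-insensitive: the data lines are sorted before comparison.
    assert_batches_sorted_eq!(expected, &batches);
    // Order-sensitive: the formatted output must match line by line.
    assert_batches_eq!(expected, &batches);
    Ok(())
}
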
@@ -1640,7 +1617,8 @@ mod tests {
            "+----+----+----+----+----+----+",
        ];
 
-        assert_batches_sorted_eq!(expected, &batches);
+        // Inner join output is expected to preserve both inputs order
+        assert_batches_eq!(expected, &batches);
 
        Ok(())
    }
@@ -1686,7 +1664,8 @@ mod tests {
            "+----+----+----+----+----+----+",
        ];
 
-        assert_batches_sorted_eq!(expected, &batches);
+        // Inner join output is expected to preserve both inputs order
+        assert_batches_eq!(expected, &batches);
 
        Ok(())
    }
@@ -1740,7 +1719,8 @@ mod tests {
            "+----+----+----+----+----+----+",
        ];
 
-        assert_batches_sorted_eq!(expected, &batches);
+        // Inner join output is expected to preserve both inputs order
+        assert_batches_eq!(expected, &batches);
 
        Ok(())
    }
@@ -1789,7 +1769,9 @@ mod tests {
            "| 1 | 4 | 7 | 10 | 4 | 70 |",
            "+----+----+----+----+----+----+",
        ];
-        assert_batches_sorted_eq!(expected, &batches);
+
+        // Inner join output is expected to preserve both inputs order
+        assert_batches_eq!(expected, &batches);
 
        // second part
        let stream = join.execute(1, task_ctx.clone())?;
@@ -1804,7 +1786,8 @@ mod tests {
            "+----+----+----+----+----+----+",
        ];
 
-        assert_batches_sorted_eq!(expected, &batches);
+        // Inner join output is expected to preserve both inputs order
+        assert_batches_eq!(expected, &batches);
 
        Ok(())
    }
@@ -2734,6 +2717,7 @@ mod tests {
            None,
            JoinSide::Left,
            None,
+            false,
        )?;
 
        let mut left_ids = UInt64Builder::with_capacity(0);

datafusion/physical-plan/src/joins/symmetric_hash_join.rs (2 additions, 0 deletions)

@@ -770,6 +770,7 @@ pub(crate) fn join_with_probe_batch(
        filter,
        build_hash_joiner.build_side,
        Some(build_hash_joiner.deleted_offset),
+        false,
    )?;
    if need_to_produce_result_in_final(build_hash_joiner.build_side, join_type) {
        record_visited_indices(
@@ -882,6 +883,7 @@ impl OneSideHashJoiner {
            random_state,
            &mut self.hashes_buffer,
            self.deleted_offset,
+            false,
        )?;
        Ok(())
    }

datafusion/physical-plan/src/joins/utils.rs (78 additions, 1 deletion)

@@ -31,7 +31,7 @@ use crate::{ColumnStatistics, ExecutionPlan, Partitioning, Statistics};
 
 use arrow::array::{
     downcast_array, new_null_array, Array, BooleanBufferBuilder, UInt32Array,
-    UInt32Builder, UInt64Array,
+    UInt32BufferBuilder, UInt32Builder, UInt64Array, UInt64BufferBuilder,
 };
 use arrow::compute;
 use arrow::datatypes::{Field, Schema, SchemaBuilder};
@@ -151,6 +151,83 @@ pub trait JoinHashMapType {
     fn get_map(&self) -> &RawTable<(u64, u64)>;
     /// Returns a reference to the next.
     fn get_list(&self) -> &Self::NextType;
+
+    /// Updates hashmap from iterator of row indices & row hashes pairs.
+    fn update_from_iter<'a>(
+        &mut self,
+        iter: impl Iterator<Item = (usize, &'a u64)>,
+        deleted_offset: usize,
+    ) {
+        let (mut_map, mut_list) = self.get_mut();
+        for (row, hash_value) in iter {
+            let item = mut_map.get_mut(*hash_value, |(hash, _)| *hash_value == *hash);
+            if let Some((_, index)) = item {
+                // Already exists: add index to next array
+                let prev_index = *index;
+                // Store new value inside hashmap
+                *index = (row + 1) as u64;
+                // Update chained Vec at row + offset with previous value
+                mut_list[row - deleted_offset] = prev_index;
+            } else {
+                mut_map.insert(
+                    *hash_value,
+                    // store the value + 1 as 0 value reserved for end of list
+                    (*hash_value, (row + 1) as u64),
+                    |(hash, _)| *hash,
+                );
+                // chained list at (row + offset) is already initialized with 0
+                // meaning end of list
+            }
+        }
+    }
+
+    /// Returns all pairs of row indices matched by hash
+    fn get_matched_indices<'a>(
+        &self,
+        iter: impl Iterator<Item = (usize, &'a u64)>,
+        deleted_offset: Option<usize>,
+    ) -> (UInt64BufferBuilder, UInt32BufferBuilder) {
+        let mut input_indices = UInt32BufferBuilder::new(0);
+        let mut matched_indices = UInt64BufferBuilder::new(0);
+
+        let hash_map = self.get_map();
+        let next_chain = self.get_list();
+        for (row, hash_value) in iter {
+            // Get the hash and find it in the build index
+
+            // For every item on the build and probe we check if it matches
+            // This possibly contains rows with hash collisions,
+            // So we have to check here whether rows are equal or not
+            if let Some((_, index)) =
+                hash_map.get(*hash_value, |(hash, _)| *hash_value == *hash)
+            {
+                let mut i = *index - 1;
+                loop {
+                    let build_row_value = if let Some(offset) = deleted_offset {
+                        // This arguments means that we prune the next index way before here.
+                        if i < offset as u64 {
+                            // End of the list due to pruning
+                            break;
+                        }
+                        i - offset as u64
+                    } else {
+                        i
+                    };
+                    matched_indices.append(build_row_value);
+                    input_indices.append(row as u32);
+                    // Follow the chain to get the next index value
+                    let next = next_chain[build_row_value as usize];
+                    if next == 0 {
+                        // end of list
+                        break;
+                    }
+                    i = next - 1;
+                }
+            }
+        }
+
+        (matched_indices, input_indices)
+    }
 }
 
 /// Implementation of `JoinHashMapType` for `JoinHashMap`.
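
Both default methods factor the build-side insertion loop and the probe-side chain traversal out of hash_join.rs, so JoinHashMap and the symmetric join's pruning hashmap share one implementation. Below is a hedged usage sketch written as if inside the datafusion-physical-plan crate; JoinHashMap and its with_capacity constructor are crate-internal, so the paths and visibility assumed here may not hold outside the crate.

use arrow::array::{PrimitiveArray, UInt32Array, UInt64Array};

use crate::joins::utils::{JoinHashMap, JoinHashMapType};

fn fifo_roundtrip() {
    // Three build rows with identical key hashes. Hash values are normally produced
    // by create_hashes; literal values are used here only for brevity.
    let build_hashes: Vec<u64> = vec![10, 10, 10];
    let mut map = JoinHashMap::with_capacity(build_hashes.len());

    // Reversed insertion keeps the first build row as the chain head (FIFO).
    map.update_from_iter(build_hashes.iter().enumerate().rev(), 0);

    // One probe row whose hash matches all three build rows.
    let probe_hashes: Vec<u64> = vec![10];
    let (mut build_indices, mut probe_indices) =
        map.get_matched_indices(probe_hashes.iter().enumerate(), None);

    let build: UInt64Array = PrimitiveArray::new(build_indices.finish().into(), None);
    let probe: UInt32Array = PrimitiveArray::new(probe_indices.finish().into(), None);

    // Build rows come back in their original order for the single probe row.
    assert_eq!(build.values().to_vec(), vec![0, 1, 2]);
    assert_eq!(probe.values().to_vec(), vec![0, 0, 0]);
}
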
