4 changes: 2 additions & 2 deletions src/query/expression/src/aggregate/aggregate_hashtable.rs
@@ -189,8 +189,8 @@ impl AggregateHashTable {
group_hash_columns(group_columns, &mut state.group_hashes);

let new_group_count = if self.direct_append {
for idx in 0..row_count {
state.empty_vector[idx] = idx;
for i in 0..row_count {
state.empty_vector[i] = i.into();
}
self.payload.append_rows(state, row_count, group_columns);
row_count
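Note on the `i.into()` change above: together with `row.to_usize()` in the hash_index tests further down, it suggests the selection vectors (`empty_vector`, `no_match_vector`, ...) now hold a dedicated row-index type rather than bare `usize`. That type is not shown in this diff, so the sketch below is only an assumption about its shape, not the crate's actual definition:

```rust
// Hypothetical newtype for selection-vector entries; the real name and
// representation are not visible in this diff.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
struct RowIndex(u32);

impl From<usize> for RowIndex {
    fn from(i: usize) -> Self {
        // Assumes row counts stay small (e.g. BATCH_SIZE = 2048 per the mod.rs hunk).
        RowIndex(i as u32)
    }
}

impl RowIndex {
    fn to_usize(self) -> usize {
        self.0 as usize
    }
}

fn main() {
    // Mirrors the rewritten loop: `state.empty_vector[i] = i.into()`.
    let mut empty_vector = vec![RowIndex::default(); 4];
    for i in 0..empty_vector.len() {
        empty_vector[i] = i.into();
    }
    assert_eq!(empty_vector[3].to_usize(), 3);
}
```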
63 changes: 37 additions & 26 deletions src/query/expression/src/aggregate/hash_index.rs
@@ -12,6 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::fmt::Debug;

use super::payload_row::CompareState;
use super::PartitionedPayload;
use super::ProbeState;
use super::RowPtr;
@@ -94,7 +97,7 @@ const SALT_MASK: u64 = 0xFFFF000000000000;
const POINTER_MASK: u64 = 0x0000FFFFFFFFFFFF;

// The high 16 bits are the salt, the low 48 bits are the pointer address
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[derive(Clone, Copy, PartialEq, Eq, Default)]
pub(super) struct Entry(pub(super) u64);

impl Entry {
Expand Down Expand Up @@ -133,6 +136,15 @@ impl Entry {
}
}

impl Debug for Entry {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_tuple("Entry")
.field(&self.get_salt())
.field(&self.get_pointer())
.finish()
}
}

pub(super) trait TableAdapter {
fn append_rows(&mut self, state: &mut ProbeState, new_entry_count: usize);

@@ -152,16 +164,10 @@ impl HashIndex {
mut adapter: impl TableAdapter,
) -> usize {
for (i, row) in state.no_match_vector[..row_count].iter_mut().enumerate() {
*row = i;
*row = i.into();
state.slots[i] = self.init_slot(state.group_hashes[i]);
}

let mut slots = state.get_temp();
slots.extend(
state.group_hashes[..row_count]
.iter()
.map(|hash| self.init_slot(*hash)),
);

let mut new_group_count = 0;
let mut remaining_entries = row_count;

@@ -172,11 +178,11 @@

// 1. inject new_group_count, new_entry_count, need_compare_count, no_match_count
for row in state.no_match_vector[..remaining_entries].iter().copied() {
let slot = &mut slots[row];
let hash = state.group_hashes[row];

let slot = &mut state.slots[row];
let is_new;
(*slot, is_new) = self.find_or_insert(*slot, Entry::hash_to_salt(hash));

let salt = Entry::hash_to_salt(state.group_hashes[row]);
(*slot, is_new) = self.find_or_insert(*slot, salt);

if is_new {
state.empty_vector[new_entry_count] = row;
@@ -194,7 +200,7 @@
adapter.append_rows(state, new_entry_count);

for row in state.empty_vector[..new_entry_count].iter().copied() {
let entry = self.mut_entry(slots[row]);
let entry = self.mut_entry(state.slots[row]);
entry.set_pointer(state.addresses[row]);
debug_assert_eq!(entry.get_pointer(), state.addresses[row]);
}
@@ -206,7 +212,7 @@
.iter()
.copied()
{
let entry = self.mut_entry(slots[row]);
let entry = self.mut_entry(state.slots[row]);

debug_assert!(entry.is_occupied());
debug_assert_eq!(entry.get_salt(), (state.group_hashes[row] >> 48) as u16);
@@ -219,7 +225,7 @@

// 5. Linear probing, just increase iter_times
for row in state.no_match_vector[..no_match_count].iter().copied() {
let slot = &mut slots[row];
let slot = &mut state.slots[row];
*slot += 1;
if *slot >= self.capacity {
*slot = 0;
@@ -228,7 +234,6 @@
remaining_entries = no_match_count;
}

state.save_temp(slots);
self.count += new_group_count;

new_group_count
@@ -251,7 +256,13 @@
need_compare_count: usize,
no_match_count: usize,
) -> usize {
state.row_match_columns(
// todo: compare hash first if NECESSARY
CompareState {
address: &state.addresses,
compare: &mut state.group_compare_vector,
no_matched: &mut state.no_match_vector,
}
.row_match_entries(
self.group_columns,
&self.payload.row_layout,
(need_compare_count, no_match_count),
@@ -284,8 +295,10 @@ mod tests {
}

fn init_state(&self) -> ProbeState {
let mut state = ProbeState::default();
state.row_count = self.incoming.len();
let mut state = ProbeState {
row_count: self.incoming.len(),
..Default::default()
};

for (i, (_, hash)) in self.incoming.iter().enumerate() {
state.group_hashes[i] = *hash
Expand Down Expand Up @@ -323,12 +336,12 @@ mod tests {

impl TableAdapter for &mut TestTableAdapter {
fn append_rows(&mut self, state: &mut ProbeState, new_entry_count: usize) {
for row in state.empty_vector[..new_entry_count].iter().copied() {
let (key, hash) = self.incoming[row];
for row in state.empty_vector[..new_entry_count].iter() {
let (key, hash) = self.incoming[*row];
let value = key + 20;

self.payload.push((key, hash, value));
state.addresses[row] = self.get_row_ptr(true, row);
state.addresses[*row] = self.get_row_ptr(true, row.to_usize());
}
}

@@ -344,9 +357,7 @@
{
let incoming = self.incoming[row];

let row_ptr = state.addresses[row];

let (key, hash, _) = self.get_payload(row_ptr);
let (key, hash, _) = self.get_payload(state.addresses[row]);

const POINTER_MASK: u64 = 0x0000FFFFFFFFFFFF;
assert_eq!(incoming.1 | POINTER_MASK, hash | POINTER_MASK);
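For context on the `Entry` changes above: an entry packs a 16-bit salt (taken from the top of the group hash) and a 48-bit row-pointer address into a single `u64`, per `SALT_MASK`/`POINTER_MASK` and the new `Debug` impl. A self-contained sketch of that packing — `PackedEntry` and its methods are illustrative stand-ins, not the crate's private API:

```rust
// Sketch of the salt/pointer packing used by `Entry`:
// high 16 bits = salt from the top of the hash, low 48 bits = pointer address.
const SALT_MASK: u64 = 0xFFFF_0000_0000_0000;
const POINTER_MASK: u64 = 0x0000_FFFF_FFFF_FFFF;

#[derive(Clone, Copy)]
struct PackedEntry(u64);

impl PackedEntry {
    // Keep only the top 16 bits of the hash as the salt.
    fn hash_to_salt(hash: u64) -> u64 {
        hash & SALT_MASK
    }

    fn set_pointer(&mut self, ptr: u64) {
        // Pointers must fit in 48 bits so they never clobber the salt.
        debug_assert_eq!(ptr & SALT_MASK, 0);
        self.0 = (self.0 & SALT_MASK) | (ptr & POINTER_MASK);
    }

    fn get_pointer(&self) -> u64 {
        self.0 & POINTER_MASK
    }

    fn get_salt(&self) -> u16 {
        (self.0 >> 48) as u16
    }
}

fn main() {
    let hash = 0xABCD_1234_5678_9ABC_u64;
    let mut e = PackedEntry(PackedEntry::hash_to_salt(hash));
    e.set_pointer(0x0000_0000_DEAD_BEEF);
    // Matches the debug assertion in the diff: salt == (hash >> 48) as u16.
    assert_eq!(e.get_salt(), (hash >> 48) as u16);
    assert_eq!(e.get_pointer(), 0x0000_0000_DEAD_BEEF);
}
```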
5 changes: 3 additions & 2 deletions src/query/expression/src/aggregate/mod.rs
@@ -39,8 +39,9 @@ use hash_index::Entry;
pub use partitioned_payload::*;
pub use payload::*;
pub use payload_flush::*;
pub use probe_state::*;
use row_ptr::RowPtr;
pub use probe_state::ProbeState;
use probe_state::*;
use row_ptr::*;

// A batch size to probe, flush, repartition, etc.
pub(crate) const BATCH_SIZE: usize = 2048;
105 changes: 62 additions & 43 deletions src/query/expression/src/aggregate/partitioned_payload.rs
@@ -28,6 +28,28 @@ use crate::ProjectedBlock;
use crate::StatesLayout;
use crate::BATCH_SIZE;

#[derive(Debug, Clone, Copy)]
struct PartitionMask {
mask: u64,
shift: u64,
}

impl PartitionMask {
fn new(partition_count: u64) -> Self {
let radix_bits = partition_count.trailing_zeros() as u64;
debug_assert_eq!(1 << radix_bits, partition_count);

let shift = 48 - radix_bits;
let mask = ((1 << radix_bits) - 1) << shift;

Self { mask, shift }
}

pub fn index(&self, hash: u64) -> usize {
((hash & self.mask) >> self.shift) as _
}
}

pub struct PartitionedPayload {
pub payloads: Vec<Payload>,
pub group_types: Vec<DataType>,
@@ -37,9 +59,7 @@ pub struct PartitionedPayload {

pub arenas: Vec<Arc<Bump>>,

partition_count: u64,
mask_v: u64,
shift_v: u64,
partition_mask: PartitionMask,
}

unsafe impl Send for PartitionedPayload {}
@@ -52,9 +72,6 @@ impl PartitionedPayload {
partition_count: u64,
arenas: Vec<Arc<Bump>>,
) -> Self {
let radix_bits = partition_count.trailing_zeros() as u64;
debug_assert_eq!(1 << radix_bits, partition_count);

let states_layout = if !aggrs.is_empty() {
Some(get_states_layout(&aggrs).unwrap())
} else {
@@ -72,7 +89,7 @@ impl PartitionedPayload {
})
.collect_vec();

let offsets = RowLayout {
let row_layout = RowLayout {
states_layout,
..payloads[0].row_layout.clone()
};
@@ -81,12 +98,10 @@ impl PartitionedPayload {
payloads,
group_types,
aggrs,
row_layout: offsets,
partition_count,
row_layout,

arenas,
mask_v: mask(radix_bits),
shift_v: shift(radix_bits),
partition_mask: PartitionMask::new(partition_count),
}
}

@@ -119,10 +134,10 @@ impl PartitionedPayload {
state.reset_partitions(self.partition_count());
for &row in &state.empty_vector[..new_group_rows] {
let hash = state.group_hashes[row];
let partition_idx = ((hash & self.mask_v) >> self.shift_v) as usize;
let partition_idx = self.partition_mask.index(hash);
let (count, sel) = &mut state.partition_entries[partition_idx];

sel[*count] = row;
sel[*count as usize] = row;
*count += 1;
}

@@ -133,7 +148,7 @@
{
if *count > 0 {
payload.reserve_append_rows(
&sel[..*count],
&sel[..*count as _],
&state.group_hashes,
&mut state.addresses,
&mut state.page_index,
@@ -149,19 +164,27 @@
return self;
}

let mut new_partition_payload = PartitionedPayload::new(
self.group_types.clone(),
self.aggrs.clone(),
new_partition_count as u64,
self.arenas.clone(),
);
let PartitionedPayload {
payloads,
group_types,
aggrs,
arenas,
..
} = self;

let mut new_partition_payload =
PartitionedPayload::new(group_types, aggrs, new_partition_count as u64, arenas);

state.clear();
for payload in payloads.into_iter() {
new_partition_payload.combine_single(payload, state, None)
}

new_partition_payload.combine(self, state);
new_partition_payload
}

pub fn combine(&mut self, other: PartitionedPayload, state: &mut PayloadFlushState) {
if other.partition_count == self.partition_count {
if other.partition_count() == self.partition_count() {
for (l, r) in self.payloads.iter_mut().zip(other.payloads.into_iter()) {
l.combine(r);
}
@@ -184,7 +207,7 @@
return;
}

if self.partition_count == 1 {
if self.partition_count() == 1 {
self.payloads[0].combine(other);
} else {
flush_state.clear();
@@ -194,13 +217,19 @@
// copy rows
let state = &*flush_state.probe_state;

for partition in (0..self.partition_count as usize)
.filter(|x| only_bucket.is_none() || only_bucket == Some(*x))
{
let (count, sel) = &state.partition_entries[partition];
if *count > 0 {
let payload = &mut self.payloads[partition];
payload.copy_rows(&sel[..*count], &flush_state.addresses);
match only_bucket {
Some(i) => {
let (count, sel) = &state.partition_entries[i];
self.payloads[i].copy_rows(&sel[..*count as _], &flush_state.addresses);
}
None => {
for ((count, sel), payload) in
state.partition_entries.iter().zip(self.payloads.iter_mut())
{
if *count > 0 {
payload.copy_rows(&sel[..*count as _], &flush_state.addresses);
}
}
}
}
}
@@ -236,10 +265,10 @@ impl PartitionedPayload {
flush_state.addresses[idx] = row_ptr;

let hash = row_ptr.hash(&self.row_layout);
let partition_idx = ((hash & self.mask_v) >> self.shift_v) as usize;
let partition_idx = self.partition_mask.index(hash);

let (count, sel) = &mut state.partition_entries[partition_idx];
sel[*count] = idx;
sel[*count as usize] = idx.into();
*count += 1;
}
flush_state.flush_page_row = end;
@@ -253,7 +282,7 @@

#[inline]
pub fn partition_count(&self) -> usize {
self.partition_count as usize
self.payloads.len()
}

#[allow(dead_code)]
Expand All @@ -266,13 +295,3 @@ impl PartitionedPayload {
self.payloads.iter().map(|x| x.memory_size()).sum()
}
}

#[inline]
fn shift(radix_bits: u64) -> u64 {
48 - radix_bits
}

#[inline]
fn mask(radix_bits: u64) -> u64 {
((1 << radix_bits) - 1) << shift(radix_bits)
}
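The `PartitionMask` introduced in this file replaces the free `shift`/`mask` helpers deleted at the end of the diff: for a power-of-two partition count, the partition index is read from the `radix_bits` bits directly below the 16-bit salt. A standalone copy of that computation with a small usage example (mirroring the struct in the diff rather than importing it):

```rust
// Standalone copy of the radix-partitioning math behind `PartitionMask`.
// `partition_count` must be a power of two; the partition index comes from
// hash bits [48 - radix_bits, 48), i.e. just below the 16-bit salt.
#[derive(Debug, Clone, Copy)]
struct PartitionMask {
    mask: u64,
    shift: u64,
}

impl PartitionMask {
    fn new(partition_count: u64) -> Self {
        let radix_bits = partition_count.trailing_zeros() as u64;
        debug_assert_eq!(1 << radix_bits, partition_count);

        let shift = 48 - radix_bits;
        let mask = ((1 << radix_bits) - 1) << shift;
        Self { mask, shift }
    }

    fn index(&self, hash: u64) -> usize {
        ((hash & self.mask) >> self.shift) as usize
    }
}

fn main() {
    let mask = PartitionMask::new(16); // radix_bits = 4, uses hash bits [44, 48)
    let hash = 0xABCD_1234_5678_9ABC_u64;
    // The nibble right below the 16-bit salt (0xABCD) is 0x1,
    // so this row lands in partition 1.
    assert_eq!(mask.index(hash), 1);
    // Indices always stay within the partition count.
    assert!(mask.index(u64::MAX) < 16);
}
```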