@@ -34,7 +34,7 @@ JoinHashTable::JoinHashTable(ClientContext &context, const vector<JoinCondition>
34
34
: buffer_manager(BufferManager::GetBufferManager(context)), conditions(conditions_p),
35
35
build_types(std::move(btypes)), output_columns(output_columns_p), entry_size(0 ), tuple_size(0 ),
36
36
vfound(Value::BOOLEAN(false )), join_type(type_p), finalized(false ), has_null(false ),
37
- radix_bits(INITIAL_RADIX_BITS), partition_start( 0 ), partition_end( 0 ) {
37
+ radix_bits(INITIAL_RADIX_BITS) {
38
38
for (idx_t i = 0 ; i < conditions.size (); ++i) {
39
39
auto &condition = conditions[i];
40
40
D_ASSERT (condition.left ->return_type == condition.right ->return_type );
@@ -108,6 +108,8 @@ JoinHashTable::JoinHashTable(ClientContext &context, const vector<JoinCondition>
108
108
auto &config = ClientConfig::GetConfig (context);
109
109
single_join_error_on_multiple_rows = config.scalar_subquery_error_on_multiple_rows ;
110
110
}
111
+
112
+ InitializePartitionMasks ();
111
113
}
112
114
113
115
JoinHashTable::~JoinHashTable () {
@@ -1430,7 +1432,10 @@ idx_t JoinHashTable::GetRemainingSize() const {
1430
1432
1431
1433
idx_t count = 0 ;
1432
1434
idx_t data_size = 0 ;
1433
- for (idx_t partition_idx = partition_end; partition_idx < num_partitions; partition_idx++) {
1435
+ for (idx_t partition_idx = 0 ; partition_idx < num_partitions; partition_idx++) {
1436
+ if (completed_partitions.RowIsValidUnsafe (partition_idx)) {
1437
+ continue ;
1438
+ }
1434
1439
count += partitions[partition_idx]->Count ();
1435
1440
data_size += partitions[partition_idx]->SizeInBytes ();
1436
1441
}
@@ -1464,6 +1469,32 @@ void JoinHashTable::SetRepartitionRadixBits(const idx_t max_ht_size, const idx_t
1464
1469
radix_bits += added_bits;
1465
1470
sink_collection =
1466
1471
make_uniq<RadixPartitionedTupleData>(buffer_manager, layout, radix_bits, layout.ColumnCount () - 1 );
1472
+
1473
+ // Need to initialize again after changing the number of bits
1474
+ InitializePartitionMasks ();
1475
+ }
1476
+
1477
+ void JoinHashTable::InitializePartitionMasks () {
1478
+ const auto num_partitions = RadixPartitioning::NumberOfPartitions (radix_bits);
1479
+
1480
+ current_partitions.Initialize (num_partitions);
1481
+ current_partitions.SetAllInvalid (num_partitions);
1482
+
1483
+ completed_partitions.Initialize (num_partitions);
1484
+ completed_partitions.SetAllInvalid (num_partitions);
1485
+ }
1486
+
1487
+ idx_t JoinHashTable::CurrentPartitionCount () const {
1488
+ const auto num_partitions = RadixPartitioning::NumberOfPartitions (radix_bits);
1489
+ D_ASSERT (current_partitions.Capacity () == num_partitions);
1490
+ return current_partitions.CountValid (num_partitions);
1491
+ }
1492
+
1493
+ idx_t JoinHashTable::FinishedPartitionCount () const {
1494
+ const auto num_partitions = RadixPartitioning::NumberOfPartitions (radix_bits);
1495
+ D_ASSERT (completed_partitions.Capacity () == num_partitions);
1496
+ // We already marked the active partitions as done, so we have to subtract them here
1497
+ return completed_partitions.CountValid (num_partitions) - CurrentPartitionCount ();
1467
1498
}
1468
1499
1469
1500
void JoinHashTable::Repartition (JoinHashTable &global_ht) {
@@ -1477,6 +1508,7 @@ void JoinHashTable::Repartition(JoinHashTable &global_ht) {
1477
1508
void JoinHashTable::Reset () {
1478
1509
data_collection->Reset ();
1479
1510
hash_map.Reset ();
1511
+ current_partitions.SetAllInvalid (RadixPartitioning::NumberOfPartitions (radix_bits));
1480
1512
finalized = false ;
1481
1513
}
1482
1514
@@ -1486,33 +1518,46 @@ bool JoinHashTable::PrepareExternalFinalize(const idx_t max_ht_size) {
1486
1518
}
1487
1519
1488
1520
const auto num_partitions = RadixPartitioning::NumberOfPartitions (radix_bits);
1489
- if (partition_end == num_partitions) {
1490
- return false ;
1521
+ D_ASSERT (current_partitions.Capacity () == num_partitions);
1522
+ D_ASSERT (completed_partitions.Capacity () == num_partitions);
1523
+ D_ASSERT (current_partitions.CheckAllInvalid (num_partitions));
1524
+
1525
+ if (completed_partitions.CheckAllValid (num_partitions)) {
1526
+ return false ; // All partitions are done
1491
1527
}
1492
1528
1493
- // Start where we left off
1529
+ // Create vector with unfinished partition indices
1494
1530
auto &partitions = sink_collection->GetPartitions ();
1495
- partition_start = partition_end;
1531
+ vector<idx_t > partition_indices;
1532
+ partition_indices.reserve (num_partitions);
1533
+ for (idx_t partition_idx = 0 ; partition_idx < num_partitions; partition_idx++) {
1534
+ if (!completed_partitions.RowIsValidUnsafe (partition_idx)) {
1535
+ partition_indices.push_back (partition_idx);
1536
+ }
1537
+ }
1538
+ // Sort partitions by size, from small to large
1539
+ std::sort (partition_indices.begin (), partition_indices.end (), [&](const idx_t &lhs, const idx_t &rhs) {
1540
+ const auto lhs_size = partitions[lhs]->SizeInBytes () + PointerTableSize (partitions[lhs]->Count ());
1541
+ const auto rhs_size = partitions[rhs]->SizeInBytes () + PointerTableSize (partitions[rhs]->Count ());
1542
+ return lhs_size < rhs_size;
1543
+ });
1496
1544
1497
- // Determine how many partitions we can do next (at least one)
1545
+ // Determine which partitions should go next
1498
1546
idx_t count = 0 ;
1499
1547
idx_t data_size = 0 ;
1500
- idx_t partition_idx;
1501
- for ( partition_idx = partition_start; partition_idx < num_partitions; partition_idx++) {
1502
- auto incl_count = count + partitions[partition_idx]->Count ();
1503
- auto incl_data_size = data_size + partitions[partition_idx]->SizeInBytes ();
1504
- auto incl_ht_size = incl_data_size + PointerTableSize (incl_count);
1548
+ for ( const auto & partition_idx : partition_indices) {
1549
+ D_ASSERT (!completed_partitions. RowIsValidUnsafe ( partition_idx));
1550
+ const auto incl_count = count + partitions[partition_idx]->Count ();
1551
+ const auto incl_data_size = data_size + partitions[partition_idx]->SizeInBytes ();
1552
+ const auto incl_ht_size = incl_data_size + PointerTableSize (incl_count);
1505
1553
if (count > 0 && incl_ht_size > max_ht_size) {
1506
- break ;
1554
+ break ; // Always add at least one partition
1507
1555
}
1508
1556
count = incl_count;
1509
1557
data_size = incl_data_size;
1510
- }
1511
- partition_end = partition_idx;
1512
-
1513
- // Move the partitions to the main data collection
1514
- for (partition_idx = partition_start; partition_idx < partition_end; partition_idx++) {
1515
- data_collection->Combine (*partitions[partition_idx]);
1558
+ current_partitions.SetValidUnsafe (partition_idx); // Mark as currently active
1559
+ data_collection->Combine (*partitions[partition_idx]); // Move partition to the main data collection
1560
+ completed_partitions.SetValidUnsafe (partition_idx); // Also already mark as done
1516
1561
}
1517
1562
D_ASSERT (Count () == count);
1518
1563
@@ -1531,7 +1576,7 @@ void JoinHashTable::ProbeAndSpill(ScanStructure &scan_structure, DataChunk &prob
1531
1576
SelectionVector false_sel (STANDARD_VECTOR_SIZE);
1532
1577
const auto true_count =
1533
1578
RadixPartitioning::Select (hashes, FlatVector::IncrementalSelectionVector (), probe_keys.size (), radix_bits,
1534
- partition_end , &true_sel, &false_sel);
1579
+ current_partitions , &true_sel, &false_sel);
1535
1580
const auto false_count = probe_keys.size () - true_count;
1536
1581
1537
1582
// can't probe these values right now, append to spill
@@ -1596,21 +1641,25 @@ void ProbeSpill::Finalize() {
1596
1641
}
1597
1642
1598
1643
void ProbeSpill::PrepareNextProbe () {
1644
+ global_spill_collection.reset ();
1599
1645
auto &partitions = global_partitions->GetPartitions ();
1600
- if (partitions.empty () || ht.partition_start == partitions.size ()) {
1646
+ if (partitions.empty () || ht.current_partitions . CheckAllInvalid ( partitions.size () )) {
1601
1647
// Can't probe, just make an empty one
1602
1648
global_spill_collection =
1603
1649
make_uniq<ColumnDataCollection>(BufferManager::GetBufferManager (context), probe_types);
1604
1650
} else {
1605
- // Move specific partitions to the global spill collection
1606
- global_spill_collection = std::move (partitions[ht.partition_start ]);
1607
- for (idx_t i = ht.partition_start + 1 ; i < ht.partition_end ; i++) {
1608
- auto &partition = partitions[i];
1609
- if (global_spill_collection->Count () == 0 ) {
1651
+ // Move current partitions to the global spill collection
1652
+ for (idx_t partition_idx = 0 ; partition_idx < partitions.size (); partition_idx++) {
1653
+ if (!ht.current_partitions .RowIsValidUnsafe (partition_idx)) {
1654
+ continue ;
1655
+ }
1656
+ auto &partition = partitions[partition_idx];
1657
+ if (!global_spill_collection) {
1610
1658
global_spill_collection = std::move (partition);
1611
- } else {
1659
+ } else if (partition-> Count () != 0 ) {
1612
1660
global_spill_collection->Combine (*partition);
1613
1661
}
1662
+ partition.reset ();
1614
1663
}
1615
1664
}
1616
1665
consumer = make_uniq<ColumnDataConsumer>(*global_spill_collection, column_ids);
0 commit comments