diff --git a/table/multiget_context.h b/table/multiget_context.h index 76027a9520ff..90f3317e84f6 100644 --- a/table/multiget_context.h +++ b/table/multiget_context.h @@ -6,6 +6,7 @@ #pragma once #include #include +#include #include #include "db/dbformat.h" @@ -97,12 +98,11 @@ class MultiGetContext { // there is negligible benefit for batches exceeding this. Keeping this < 32 // simplifies iteration, as well as reduces the amount of stack allocations // that need to be performed - static const int MAX_BATCH_SIZE = 32; + static const int MAX_BATCH_SIZE = 1024; // A bitmask of at least MAX_BATCH_SIZE - 1 bits, so that // Mask{1} << MAX_BATCH_SIZE is well defined - using Mask = uint64_t; - static_assert(MAX_BATCH_SIZE < sizeof(Mask) * 8); + using Mask = std::bitset; MultiGetContext(autovector* sorted_keys, size_t begin, size_t num_keys, SequenceNumber snapshot, @@ -198,9 +198,8 @@ class MultiGetContext { Iterator(const Range* range, size_t idx) : range_(range), ctx_(range->ctx_), index_(idx) { while (index_ < range_->end_ && - (Mask{1} << index_) & (range_->ctx_->value_mask_ | range_->skip_mask_ | - range_->invalid_mask_)) + range_->invalid_mask_).test(index_)) index_++; } @@ -214,9 +213,8 @@ class MultiGetContext { Iterator& operator++() { while (++index_ < range_->end_ && - (Mask{1} << index_) & (range_->ctx_->value_mask_ | range_->skip_mask_ | - range_->invalid_mask_)) + range_->invalid_mask_).test(index_)) ; return *this; } @@ -264,8 +262,6 @@ class MultiGetContext { } skip_mask_ = mget_range.skip_mask_; invalid_mask_ = mget_range.invalid_mask_; - assert(start_ < 64); - assert(end_ < 64); } Range() = default; @@ -274,27 +270,27 @@ class MultiGetContext { Iterator end() const { return Iterator(this, end_); } - bool empty() const { return RemainingMask() == 0; } + bool empty() const { return RemainingMask().none(); } void SkipIndex(size_t index) { skip_mask_ |= Mask{1} << index; } void SkipKey(const Iterator& iter) { SkipIndex(iter.index_); } bool IsKeySkipped(const Iterator& iter) const { - return skip_mask_ & (Mask{1} << iter.index_); + return skip_mask_.test(iter.index_); } // Update the value_mask_ in MultiGetContext so its // immediately reflected in all the Range Iterators void MarkKeyDone(Iterator& iter) { - ctx_->value_mask_ |= (Mask{1} << iter.index_); + ctx_->value_mask_.set(iter.index_); } bool CheckKeyDone(Iterator& iter) const { - return ctx_->value_mask_ & (Mask{1} << iter.index_); + return ctx_->value_mask_.test(iter.index_); } - uint64_t KeysLeft() const { return BitsSetToOne(RemainingMask()); } + uint64_t KeysLeft() const { return RemainingMask().count(); } void AddSkipsFrom(const Range& other) { assert(ctx_ == other.ctx_); @@ -335,8 +331,6 @@ class MultiGetContext { skip_mask_ |= rhs.skip_mask_ & RangeMask(rhs.start_, rhs.end_); invalid_mask_ |= (rhs.invalid_mask_ | rhs.skip_mask_) & RangeMask(rhs.start_, rhs.end_); - assert(start_ < 64); - assert(end_ < 64); return *this; } @@ -373,22 +367,20 @@ class MultiGetContext { end_(num_keys), skip_mask_(0), invalid_mask_(0) { - assert(num_keys < 64); } static Mask RangeMask(size_t start, size_t end) { - return (((Mask{1} << (end - start)) - 1) << start); + return (Mask(0).flip() <<= (end - start)).flip() <<= start; } Mask RemainingMask() const { - return (((Mask{1} << end_) - 1) & ~((Mask{1} << start_) - 1) & - ~(ctx_->value_mask_ | skip_mask_)); + return RangeMask(start_, end_) & ~(ctx_->value_mask_ | skip_mask_); } size_t FindLastRemaining() const { Mask mask = RemainingMask(); - size_t index = (mask >>= start_) ? start_ : 0; - while (mask >>= 1) { + size_t index = (mask >>= start_).any() ? start_ : 0; + while ((mask >>= 1).any()) { index++; } return index;