Skip to content

Commit

Permalink
remove absl/random and absl/memory, add absl::btree_map
Browse files Browse the repository at this point in the history
  • Loading branch information
taku910 committed Jan 7, 2024
1 parent adf9e81 commit f5c7363
Show file tree
Hide file tree
Showing 15 changed files with 103 additions and 211 deletions.
9 changes: 4 additions & 5 deletions src/bpe_model_trainer.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@

#include <cstdint>
#include <limits>
#include <set>
#include <string>
#include <vector>

#include "sentencepiece_model.pb.h"
#include "third_party/absl/container/btree_set.h"
#include "third_party/absl/container/flat_hash_map.h"
#include "trainer_interface.h"

Expand Down Expand Up @@ -51,7 +51,7 @@ class Trainer : public TrainerInterface {

// Position list. Use an ordered set so that we can keep the order of occurrence.
// See EncodePos/DecodePos.
std::set<uint64_t> positions;
absl::btree_set<uint64_t> positions;

bool IsBigram() const { return left != nullptr && right != nullptr; }
std::string ToString() const;
Expand All @@ -72,8 +72,7 @@ class Trainer : public TrainerInterface {
CHECK_LE(l, std::numeric_limits<uint16_t>::max());
CHECK_LE(r, std::numeric_limits<uint16_t>::max());
const uint64_t n = (static_cast<uint64_t>(sid) << 32) |
(static_cast<uint64_t>(l) << 16) |
r;
(static_cast<uint64_t>(l) << 16) | r;
return n;
}

Expand Down Expand Up @@ -118,7 +117,7 @@ class Trainer : public TrainerInterface {
absl::flat_hash_map<uint64_t, Symbol *> symbols_cache_;

// Set of symbols from which we find the best symbol in each iteration.
std::set<Symbol *> active_symbols_;
absl::btree_set<Symbol *> active_symbols_;

// Stores symbols allocated on the heap so that we can delete them at once.
std::vector<Symbol *> allocated_;
Expand Down
5 changes: 2 additions & 3 deletions src/filesystem.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
#include <iostream>
#include <memory>

#include "third_party/absl/memory/memory.h"
#include "util.h"

#if defined(OS_WIN) && defined(UNICODE) && defined(_UNICODE)
Expand Down Expand Up @@ -105,12 +104,12 @@ using DefaultWritableFile = PosixWritableFile;

std::unique_ptr<ReadableFile> NewReadableFile(absl::string_view filename,
bool is_binary) {
return absl::make_unique<DefaultReadableFile>(filename, is_binary);
return std::make_unique<DefaultReadableFile>(filename, is_binary);
}

std::unique_ptr<WritableFile> NewWritableFile(absl::string_view filename,
bool is_binary) {
return absl::make_unique<DefaultWritableFile>(filename, is_binary);
return std::make_unique<DefaultWritableFile>(filename, is_binary);
}

} // namespace filesystem
Expand Down
14 changes: 7 additions & 7 deletions src/model_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "model_factory.h"

#include "bpe_model.h"
#include "char_model.h"
#include "model_factory.h"
#include "third_party/absl/memory/memory.h"
#include "unigram_model.h"
#include "word_model.h"

Expand All @@ -28,23 +28,23 @@ std::unique_ptr<ModelInterface> ModelFactory::Create(

switch (trainer_spec.model_type()) {
case TrainerSpec::UNIGRAM:
return absl::make_unique<unigram::Model>(model_proto);
return std::make_unique<unigram::Model>(model_proto);
break;
case TrainerSpec::BPE:
return absl::make_unique<bpe::Model>(model_proto);
return std::make_unique<bpe::Model>(model_proto);
break;
case TrainerSpec::WORD:
return absl::make_unique<word::Model>(model_proto);
return std::make_unique<word::Model>(model_proto);
break;
case TrainerSpec::CHAR:
return absl::make_unique<character::Model>(model_proto);
return std::make_unique<character::Model>(model_proto);
break;
default:
LOG(ERROR) << "Unknown model_type: " << trainer_spec.model_type();
return nullptr;
break;
}

return absl::make_unique<unigram::Model>(model_proto);
return std::make_unique<unigram::Model>(model_proto);
}
} // namespace sentencepiece
3 changes: 1 addition & 2 deletions src/model_interface.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
#include <algorithm>

#include "sentencepiece_model.pb.h"
#include "third_party/absl/memory/memory.h"
#include "third_party/absl/strings/str_format.h"
#include "util.h"

Expand Down Expand Up @@ -148,7 +147,7 @@ void ModelInterface::InitializePieces() {
}
}

matcher_ = absl::make_unique<normalizer::PrefixMatcher>(user_defined_symbols);
matcher_ = std::make_unique<normalizer::PrefixMatcher>(user_defined_symbols);
}

std::vector<absl::string_view> SplitIntoWords(absl::string_view text,
Expand Down
5 changes: 2 additions & 3 deletions src/normalizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
#include <vector>

#include "common.h"
#include "third_party/absl/memory/memory.h"
#include "third_party/absl/strings/match.h"
#include "third_party/absl/strings/string_view.h"
#include "third_party/absl/strings/strip.h"
Expand Down Expand Up @@ -58,7 +57,7 @@ void Normalizer::Init() {
if (!status_.ok()) return;

// Reads the body of double array.
trie_ = absl::make_unique<Darts::DoubleArray>();
trie_ = std::make_unique<Darts::DoubleArray>();

// The second arg of set_array is not the size of blob,
// but the number of double array units.
Expand Down Expand Up @@ -314,7 +313,7 @@ PrefixMatcher::PrefixMatcher(const std::set<absl::string_view> &dic) {
std::vector<const char *> key;
key.reserve(dic.size());
for (const auto &it : dic) key.push_back(it.data());
trie_ = absl::make_unique<Darts::DoubleArray>();
trie_ = std::make_unique<Darts::DoubleArray>();
if (trie_->build(key.size(), const_cast<char **>(&key[0]), nullptr,
nullptr) != 0) {
LOG(ERROR) << "Failed to build the TRIE for PrefixMatcher";
Expand Down
11 changes: 5 additions & 6 deletions src/sentencepiece_processor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
#include "model_interface.h"
#include "normalizer.h"
#include "sentencepiece.pb.h"
#include "third_party/absl/memory/memory.h"
#include "third_party/absl/strings/numbers.h"
#include "third_party/absl/strings/str_cat.h"
#include "third_party/absl/strings/str_join.h"
Expand Down Expand Up @@ -217,7 +216,7 @@ SentencePieceProcessor::SentencePieceProcessor() {}
SentencePieceProcessor::~SentencePieceProcessor() {}

util::Status SentencePieceProcessor::Load(absl::string_view filename) {
auto model_proto = absl::make_unique<ModelProto>();
auto model_proto = std::make_unique<ModelProto>();
RETURN_IF_ERROR(io::LoadModelProto(filename, model_proto.get()));
return Load(std::move(model_proto));
}
Expand All @@ -227,14 +226,14 @@ void SentencePieceProcessor::LoadOrDie(absl::string_view filename) {
}

util::Status SentencePieceProcessor::Load(const ModelProto &model_proto) {
auto model_proto_copy = absl::make_unique<ModelProto>();
auto model_proto_copy = std::make_unique<ModelProto>();
*model_proto_copy = model_proto;
return Load(std::move(model_proto_copy));
}

util::Status SentencePieceProcessor::LoadFromSerializedProto(
absl::string_view serialized) {
auto model_proto = absl::make_unique<ModelProto>();
auto model_proto = std::make_unique<ModelProto>();
CHECK_OR_RETURN(
model_proto->ParseFromArray(serialized.data(), serialized.size()));
return Load(std::move(model_proto));
Expand All @@ -244,11 +243,11 @@ util::Status SentencePieceProcessor::Load(
std::unique_ptr<ModelProto> model_proto) {
model_proto_ = std::move(model_proto);
model_ = ModelFactory::Create(*model_proto_);
normalizer_ = absl::make_unique<normalizer::Normalizer>(
normalizer_ = std::make_unique<normalizer::Normalizer>(
model_proto_->normalizer_spec(), model_proto_->trainer_spec());
if (model_proto_->has_denormalizer_spec() &&
!model_proto_->denormalizer_spec().precompiled_charsmap().empty()) {
denormalizer_ = absl::make_unique<normalizer::Normalizer>(
denormalizer_ = std::make_unique<normalizer::Normalizer>(
model_proto_->denormalizer_spec());
}

Expand Down
Loading

0 comments on commit f5c7363

Please sign in to comment.