diff --git a/src/compile_charsmap_main.cc b/src/compile_charsmap_main.cc index 7fc0275f..85555a9c 100644 --- a/src/compile_charsmap_main.cc +++ b/src/compile_charsmap_main.cc @@ -73,7 +73,7 @@ std::string ToHexUInt64Array( std::string ToHexData(absl::string_view data) { const char *begin = data.data(); const char *end = data.data() + data.size(); - constexpr char kHex[] = "0123456789ABCDEF"; + constexpr absl::string_view kHex = "0123456789ABCDEF"; constexpr size_t kNumOfBytesOnOneLine = 20; size_t output_count = 0; diff --git a/src/model_interface.cc b/src/model_interface.cc index 11664135..351413bf 100644 --- a/src/model_interface.cc +++ b/src/model_interface.cc @@ -49,13 +49,11 @@ absl::string_view ModelInterface::pad_piece() const { #undef RETURN_PIECE int ModelInterface::PieceToId(absl::string_view piece) const { - auto it = reserved_id_map_.find(piece); - if (it != reserved_id_map_.end()) { + if (auto it = reserved_id_map_.find(piece); it != reserved_id_map_.end()) { return it->second; } - auto it2 = pieces_.find(piece); - if (it2 != pieces_.end()) { - return it2->second; + if (auto it = pieces_.find(piece); it != pieces_.end()) { + return it->second; } return unk_id_; } @@ -160,7 +158,7 @@ std::vector SplitIntoWords(absl::string_view text, const char *end = text.data() + text.size(); // Space symbol (U+2581) - const absl::string_view kSpaceSymbol = "\xe2\x96\x81"; + constexpr absl::string_view kSpaceSymbol = "\xe2\x96\x81"; bool in_ws_sequence = false; std::vector result; @@ -223,12 +221,12 @@ int PieceToByte(absl::string_view piece) { } return m; }(); - const auto it = kMap->find(piece); - if (it == kMap->end()) { - return -1; - } else { + + if (const auto it = kMap->find(piece); it != kMap->end()) { return it->second; } + + return -1; } } // namespace sentencepiece diff --git a/src/normalizer.cc b/src/normalizer.cc index 2a1f0787..521d81db 100644 --- a/src/normalizer.cc +++ b/src/normalizer.cc @@ -103,7 +103,7 @@ util::Status Normalizer::Normalize(absl::string_view input, } // Reserves the output buffer to avoid re-allocations. - const size_t kReservedSize = input.size() * 3; + const size_t kReservedSize = input.size() * 1.5; normalized->reserve(kReservedSize); if (norm_to_orig) norm_to_orig->reserve(kReservedSize); @@ -191,14 +191,12 @@ std::string Normalizer::Normalize(absl::string_view input) const { std::pair Normalizer::NormalizePrefix( absl::string_view input) const { - std::pair result; - - if (input.empty()) return result; + if (input.empty()) return {}; if (matcher_ != nullptr) { bool found = false; const int mblen = matcher_->PrefixMatch(input, &found); - if (found) return std::make_pair(input.substr(0, mblen), mblen); + if (found) return {input.substr(0, mblen), mblen}; } size_t longest_length = 0; @@ -225,6 +223,7 @@ std::pair Normalizer::NormalizePrefix( } } + std::pair result; if (longest_length == 0 || longest_length > input.size() || longest_value >= normalized_.size()) { size_t length = 0; @@ -234,8 +233,8 @@ std::pair Normalizer::NormalizePrefix( // which is a valid Unicode of three bytes in utf8, // but here we only consume one byte. result.second = 1; - static const char kReplacementChar[] = "\xEF\xBF\xBD"; - result.first = absl::string_view(kReplacementChar); + static constexpr absl::string_view kReplacementChar = "\xEF\xBF\xBD"; + result.first = kReplacementChar; } else { result.second = length; result.first = absl::string_view(input.data(), result.second); diff --git a/src/sentencepiece_processor.cc b/src/sentencepiece_processor.cc index 73580cee..e36314bd 100644 --- a/src/sentencepiece_processor.cc +++ b/src/sentencepiece_processor.cc @@ -40,19 +40,27 @@ #include "unigram_model.h" #include "util.h" +#ifdef _USE_EXTERNAL_PROTOBUF +#include "google/protobuf/arena.h" +#else +#include "third_party/protobuf-lite/google/protobuf/arena.h" +#endif + +using ::google::protobuf::Arena; + namespace sentencepiece { namespace { // Replaces white space with U+2581 (LOWER ONE EIGHT BLOCK). -const char kSpaceSymbol[] = "\xe2\x96\x81"; +constexpr absl::string_view kSpaceSymbol = "\xe2\x96\x81"; // Encodes into U+2047 (DOUBLE QUESTION MARK), // since this character can be useful both for user and // developer. We can easily figure out that is emitted. -const char kDefaultUnknownSymbol[] = " \xE2\x81\x87 "; +constexpr absl::string_view kDefaultUnknownSymbol = " \xE2\x81\x87 "; // REPLACEMENT CHARACTER (U+FFFD) in UTF-8. -const char kReplacementCharacter[] = "\xef\xbf\xbd"; +constexpr absl::string_view kReplacementCharacter = "\xef\xbf\xbd"; std::vector ToPieceArray(const std::vector &v) { std::vector out(v.size()); @@ -381,9 +389,10 @@ util::Status SentencePieceProcessor::Encode( absl::string_view input, std::vector *pieces) const { RET_CHECK_STATUS_STL(pieces); - SentencePieceText spt; - RETURN_IF_ERROR(Encode(input, &spt)); - for (const auto &sp : spt.pieces()) { + Arena arena; + auto *spt = Arena::Create(&arena); + RETURN_IF_ERROR(Encode(input, spt)); + for (const auto &sp : spt->pieces()) { pieces->emplace_back(sp.piece()); } @@ -394,11 +403,58 @@ util::Status SentencePieceProcessor::Encode(absl::string_view input, std::vector *ids) const { RET_CHECK_STATUS_STL(ids); - SentencePieceText spt; - RETURN_IF_ERROR(Encode(input, &spt)); - ids->reserve(spt.pieces().size()); - for (const auto &sp : spt.pieces()) { - ids->emplace_back(sp.id()); + // The following is a pared down version of PopulateSentencePieceText, that + // only populates the ids; skipping the surface and begin/end fields as they + // will be thrown away otherwise. + std::string normalized; + + RETURN_IF_ERROR( + normalizer_->Normalize(input, &normalized, /*norm_to_orig=*/nullptr)); + const EncodeResult result = model_->Encode(normalized); + const bool byte_fallback_enabled = model_->ByteFallbackEnabled(); + + bool is_prev_unk = false; + ids->reserve(result.size()); + + for (const auto &[w, id] : result) { + RET_CHECK(!w.empty()) << "Empty piece is not allowed."; + if (IsControl(id)) { + ids->emplace_back(id); + is_prev_unk = false; + } else { + const bool is_unk = IsUnknown(id); + if (is_unk && byte_fallback_enabled) { + for (size_t i = 0; i < w.size(); ++i) { + const auto sp_id = + model_->PieceToId(ByteToPiece(static_cast(w[i]))); + ids->emplace_back(sp_id); + } + } else { + // Merge continuous runs of unknown pieces. + if (!is_prev_unk || !is_unk) { + ids->emplace_back(id); + } + } + is_prev_unk = is_unk; + } + } + + // Inlining ApplyExtraOptions but just the ids part. + for (const auto &extra_option : encode_extra_options_) { + switch (extra_option) { + case REVERSE: + std::reverse(ids->begin(), ids->end()); + break; + case EOS: + ids->emplace_back(PieceToId(model_->eos_piece())); + break; + case BOS: + ids->insert(ids->begin(), PieceToId(model_->bos_piece())); + break; + default: + ids->clear(); + return util::InternalError("unknown extra_option type."); + } } return util::OkStatus(); @@ -414,9 +470,12 @@ util::Status SentencePieceProcessor::Decode( std::string *detokenized) const { RET_CHECK_STATUS_STL(detokenized); - SentencePieceText spt; - RETURN_IF_ERROR(Decode(pieces, &spt)); - *detokenized = std::move(*spt.mutable_text()); + // Allocate SentencePieceText on an arena to improve allocation and + // deallocation costs. + Arena arena; + auto *spt = Arena::Create(&arena); + RETURN_IF_ERROR(Decode(pieces, spt)); + *detokenized = std::move(*spt->mutable_text()); return util::OkStatus(); } @@ -425,9 +484,12 @@ util::Status SentencePieceProcessor::Decode(const std::vector &ids, std::string *detokenized) const { RET_CHECK_STATUS_STL(detokenized); - SentencePieceText spt; - RETURN_IF_ERROR(Decode(ids, &spt)); - *detokenized = std::move(*spt.mutable_text()); + // Allocate SentencePieceText on an arena to improve allocation and + // deallocation costs. + Arena arena; + auto *spt = Arena::Create(&arena); + RETURN_IF_ERROR(Decode(ids, spt)); + *detokenized = std::move(*spt->mutable_text()); return util::OkStatus(); } @@ -437,10 +499,11 @@ util::Status SentencePieceProcessor::NBestEncode( std::vector> *pieces) const { RET_CHECK_STATUS_STL(pieces); - NBestSentencePieceText spt; - RETURN_IF_ERROR(NBestEncode(input, nbest_size, &spt)); - pieces->reserve(spt.nbests().size()); - for (const auto &nbest : spt.nbests()) { + Arena arena; + auto *spt = Arena::Create(&arena); + RETURN_IF_ERROR(NBestEncode(input, nbest_size, spt)); + pieces->reserve(spt->nbests().size()); + for (const auto &nbest : spt->nbests()) { std::vector &result = pieces->emplace_back(); result.reserve(nbest.pieces().size()); for (const auto &sp : nbest.pieces()) { @@ -456,10 +519,11 @@ util::Status SentencePieceProcessor::NBestEncode( std::vector> *ids) const { RET_CHECK_STATUS_STL(ids); - NBestSentencePieceText spt; - RETURN_IF_ERROR(NBestEncode(input, nbest_size, &spt)); - ids->reserve(spt.nbests().size()); - for (const auto &nbest : spt.nbests()) { + Arena arena; + auto *spt = Arena::Create(&arena); + RETURN_IF_ERROR(NBestEncode(input, nbest_size, spt)); + ids->reserve(spt->nbests().size()); + for (const auto &nbest : spt->nbests()) { std::vector &result = ids->emplace_back(); result.reserve(nbest.pieces().size()); for (const auto &sp : nbest.pieces()) { @@ -475,10 +539,11 @@ util::Status SentencePieceProcessor::SampleEncode( std::vector *pieces) const { RET_CHECK_STATUS_STL(pieces); - SentencePieceText spt; - RETURN_IF_ERROR(SampleEncode(input, nbest_size, alpha, &spt)); - pieces->reserve(spt.pieces().size()); - for (const auto &sp : spt.pieces()) { + Arena arena; + auto *spt = Arena::Create(&arena); + RETURN_IF_ERROR(SampleEncode(input, nbest_size, alpha, spt)); + pieces->reserve(spt->pieces().size()); + for (const auto &sp : spt->pieces()) { pieces->emplace_back(sp.piece()); } @@ -490,9 +555,10 @@ util::Status SentencePieceProcessor::SampleEncode(absl::string_view input, std::vector *ids) const { RET_CHECK_STATUS_STL(ids); - SentencePieceText spt; - RETURN_IF_ERROR(SampleEncode(input, nbest_size, alpha, &spt)); - for (const auto &sp : spt.pieces()) { + Arena arena; + auto *spt = Arena::Create(&arena); + RETURN_IF_ERROR(SampleEncode(input, nbest_size, alpha, spt)); + for (const auto &sp : spt->pieces()) { ids->emplace_back(sp.id()); } @@ -505,20 +571,21 @@ util::Status SentencePieceProcessor::SampleEncodeAndScore( std::vector, float>> *pieces) const { RET_CHECK_STATUS_STL(pieces); - NBestSentencePieceText spt; + Arena arena; + auto *spt = Arena::Create(&arena); RETURN_IF_ERROR( - SampleEncodeAndScore(input, num_samples, alpha, wor, include_best, &spt)); + SampleEncodeAndScore(input, num_samples, alpha, wor, include_best, spt)); pieces->clear(); - pieces->reserve(spt.nbests_size()); + pieces->reserve(spt->nbests_size()); - for (const auto &nbest : spt.nbests()) { + for (const auto &nbest : spt->nbests()) { std::vector result; result.reserve(nbest.pieces_size()); for (const auto &sp : nbest.pieces()) { result.emplace_back(sp.piece()); } - pieces->emplace_back(result, nbest.score()); + pieces->emplace_back(std::move(result), nbest.score()); } return util::OkStatus(); @@ -530,20 +597,21 @@ util::Status SentencePieceProcessor::SampleEncodeAndScore( std::vector, float>> *ids) const { RET_CHECK_STATUS_STL(ids); - NBestSentencePieceText spt; + Arena arena; + auto *spt = Arena::Create(&arena); RETURN_IF_ERROR( - SampleEncodeAndScore(input, num_samples, alpha, wor, include_best, &spt)); + SampleEncodeAndScore(input, num_samples, alpha, wor, include_best, spt)); ids->clear(); - ids->reserve(spt.nbests_size()); + ids->reserve(spt->nbests_size()); - for (const auto &nbest : spt.nbests()) { + for (const auto &nbest : spt->nbests()) { std::vector result; result.reserve(nbest.pieces_size()); for (const auto &sp : nbest.pieces()) { result.emplace_back(sp.id()); } - ids->emplace_back(result, nbest.score()); + ids->emplace_back(std::move(result), nbest.score()); } return util::OkStatus(); @@ -587,7 +655,7 @@ util::Status SentencePieceProcessor::PopulateSentencePieceText( // Decomposes an unknown piece into UTF-8 bytes for (size_t i = 0; i < w.size(); ++i) { // Create a byte piece - const char b = w[i]; + const uint8_t b = static_cast(w[i]); SentencePieceText::SentencePiece *sp = spt->add_pieces(); std::string &piece = *sp->mutable_piece(); piece = ByteToPiece(b); @@ -772,9 +840,9 @@ util::Status SentencePieceProcessor::Decode( SentencePieceText *spt) const { RET_CHECK_STATUS_PROTO(spt); - const char *unk_surface = kDefaultUnknownSymbol; + absl::string_view unk_surface = kDefaultUnknownSymbol; if (model_proto_ && model_proto_->trainer_spec().has_unk_surface()) - unk_surface = model_proto_->trainer_spec().unk_surface().c_str(); + unk_surface = model_proto_->trainer_spec().unk_surface(); // Returns decoded piece and a boolean indicating if the function has consumed // a bos whitespace token (a piece starting with a kSpaceSymbol). This is used @@ -787,7 +855,7 @@ util::Status SentencePieceProcessor::Decode( return std::make_pair("", false); // invisible symbol. } else if (IsUnknown(id)) { if (IdToPiece(id) == piece) { // - return std::make_pair(unk_surface, false); + return std::make_pair(std::string(unk_surface), false); } else { // return piece when piece is not . return std::make_pair(std::string(piece), false); } @@ -937,9 +1005,8 @@ util::Status SentencePieceProcessor::Decode(const std::vector &ids, util::Status SentencePieceProcessor::Normalize(absl::string_view input, std::string *normalized) const { - std::vector norm_to_orig; RET_CHECK(normalizer_); - return normalizer_->Normalize(input, normalized, &norm_to_orig); + return normalizer_->Normalize(input, normalized, nullptr); } util::Status SentencePieceProcessor::Normalize( @@ -971,6 +1038,15 @@ const std::string &SentencePieceProcessor::IdToPiece(int id) const { return model_->IdToPiece(id); } +bool SentencePieceProcessor::SafeIdToPiece(int id, std::string *piece) const { + RET_CHECK_OR_RETURN_DEFAULT(false); + if (id < 0 || id >= model_->GetPieceSize()) { + return false; + } + *piece = IdToPiece(id); + return true; +} + float SentencePieceProcessor::GetScore(int id) const { RET_CHECK_OR_RETURN_DEFAULT(0.0); return model_->GetScore(id); @@ -997,27 +1073,23 @@ bool SentencePieceProcessor::IsByte(int id) const { } int SentencePieceProcessor::unk_id() const { - const int id = PieceToId(absl::string_view(model_->unk_piece().data())); - if (IsUnknown(id)) return id; - return -1; + const int id = PieceToId(model_->unk_piece()); + return IsUnknown(id) ? id : -1; } int SentencePieceProcessor::bos_id() const { - const int id = PieceToId(absl::string_view(model_->bos_piece().data())); - if (IsControl(id)) return id; - return -1; + const int id = PieceToId(model_->bos_piece()); + return IsControl(id) ? id : -1; } int SentencePieceProcessor::eos_id() const { - const int id = PieceToId(absl::string_view(model_->eos_piece().data())); - if (IsControl(id)) return id; - return -1; + const int id = PieceToId(model_->eos_piece()); + return IsControl(id) ? id : -1; } int SentencePieceProcessor::pad_id() const { - const int id = PieceToId(absl::string_view(model_->pad_piece().data())); - if (IsControl(id)) return id; - return -1; + const int id = PieceToId(model_->pad_piece()); + return IsControl(id) ? id : -1; } // static @@ -1032,7 +1104,7 @@ util::Status SentencePieceProcessor::ApplyExtraOptions( break; case EOS: { auto *piece = spt->add_pieces(); - piece->set_id(PieceToId(absl::string_view(model_->eos_piece().data()))); + piece->set_id(PieceToId(model_->eos_piece())); piece->set_piece(model_->eos_piece().data(), model_->eos_piece().size()); piece->set_begin(spt->text().size()); @@ -1045,7 +1117,7 @@ util::Status SentencePieceProcessor::ApplyExtraOptions( array->SwapElements(i - 1, i); } auto *piece = array->Mutable(0); - piece->set_id(PieceToId(absl::string_view(model_->bos_piece().data()))); + piece->set_id(PieceToId(model_->bos_piece())); piece->set_piece(model_->bos_piece().data(), model_->bos_piece().size()); piece->set_begin(0); @@ -1061,6 +1133,7 @@ util::Status SentencePieceProcessor::ApplyExtraOptions( } } break; default: + spt->Clear(); return util::InternalError("unknown extra_option type."); } } @@ -1092,13 +1165,11 @@ util::Status SentencePieceProcessor::ParseExtraOptions( extra_options->push_back(it->second); if (it->second == SentencePieceProcessor::BOS) { - RET_CHECK( - !IsUnknown(PieceToId(absl::string_view(model_->bos_piece().data())))) + RET_CHECK(!IsUnknown(PieceToId(model_->bos_piece()))) << "id for `" << model_->bos_piece() << "` is not defined."; } if (it->second == SentencePieceProcessor::EOS) { - RET_CHECK( - !IsUnknown(PieceToId(absl::string_view(model_->eos_piece().data())))) + RET_CHECK(!IsUnknown(PieceToId(model_->eos_piece()))) << "id for `" << model_->eos_piece() << "` is not defined."; } } diff --git a/src/sentencepiece_processor.h b/src/sentencepiece_processor.h index 795a7755..cf7c6fed 100644 --- a/src/sentencepiece_processor.h +++ b/src/sentencepiece_processor.h @@ -394,7 +394,7 @@ class SentencePieceProcessor { // Given a UTF8 input, encodes it into SentencePieceText. // // When using these APIs, sentencepiece.pb.h header files must be included. - // We can also use ImutableSentencePieceText as follows. + // We can also use ImmutableSentencePieceText as follows. // // ImmutableSentencePieceText spt; // Encode("hello", spt.mutable_proto()).IgnoreError(); @@ -645,6 +645,10 @@ class SentencePieceProcessor { // Returns the string representation of vocab with `id`. virtual const std::string &IdToPiece(int id) const; + // Returns the string representation of vocab with `id`. + // Returns false when id is out of range. + virtual bool SafeIdToPiece(int id, std::string *piece) const; + // Returns the score of `id`. // Usually score is an emission log probability of unigram language // model. diff --git a/src/sentencepiece_trainer.cc b/src/sentencepiece_trainer.cc index 37189685..2d7ce715 100644 --- a/src/sentencepiece_trainer.cc +++ b/src/sentencepiece_trainer.cc @@ -134,9 +134,7 @@ util::Status SentencePieceTrainer::MergeSpecsFromArgs( RET_CHECK(normalizer_spec) << "`normalizer_spec` must not be null."; RET_CHECK(denormalizer_spec) << "`denormalizer_spec` must not be null."; - for (const auto &it : kwargs) { - const auto &key = it.first; - const auto &value = it.second; + for (const auto &[key, value] : kwargs) { // Exceptions. if (key == "normalization_rule_name") { normalizer_spec->set_name(value); diff --git a/src/unigram_model.cc b/src/unigram_model.cc index 820809c3..93e6172e 100644 --- a/src/unigram_model.cc +++ b/src/unigram_model.cc @@ -1013,6 +1013,7 @@ EncodeResult Model::EncodeOptimized(absl::string_view normalized) const { } // Backtrack to identify the best path. EncodeResult results; + results.reserve(size / 4 + 1); int ends_at = size; while (ends_at > 0) { const auto &node = best_path_ends_at[ends_at];