Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ class FastWordpieceBuilder {
trie_array_[node_id] &= 0xFFFFFEFF;
}

-absl::optional<StringVocab> vocab_;
+std::unique_ptr<StringVocab> vocab_;

int max_bytes_per_token_ = -1;

Expand Down Expand Up @@ -264,7 +264,7 @@ absl::Status FastWordpieceBuilder::BuildModel(
no_pretokenization_ = no_pretokenization;
support_detokenization_ = support_detokenization;

-vocab_.emplace(vocab);
+vocab_ = std::make_unique<StringVocab>(vocab);
if (vocab_->Size() != vocab.size()) {
return absl::FailedPreconditionError(
"Tokens in the vocabulary must be unique.");
Expand Down Expand Up @@ -830,7 +830,7 @@ absl::Status FastWordpieceBuilder::PrecomputeResultForSuffixIndicator() {
LookupStatus status = WordpieceTokenize(
suffix_indicator_, max_bytes_per_token_, /*max_chars_per_subtoken=*/-1,
suffix_indicator_, /*use_unknown_token=*/true, unk_token_,
-/*split_unknown_characters=*/false, &vocab_.value(), &subwords,
+/*split_unknown_characters=*/false, vocab_.get(), &subwords,
&begin_offset, &end_offset, &num_word_pieces);
precomputed_result_for_suffix_indicator_.reserve(subwords.size());
if (!status.success) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class PhraseBuilder {
absl::StatusOr<std::string> ExportToFlatBuffer() const;

private:
-absl::optional<StringVocab> vocab_;
+std::unique_ptr<StringVocab> vocab_;
std::vector<uint32_t> trie_data_;
std::string unk_token_;
int unk_token_id_;
Expand All @@ -64,7 +64,7 @@ absl::Status PhraseBuilder::BuildModel(const std::vector<std::string>& vocab,
prob_ = prob;
split_end_punctuation_ = split_end_punctuation;

-vocab_.emplace(vocab);
+vocab_ = std::make_unique<StringVocab>(vocab);
if (vocab_->Size() != vocab.size()) {
return absl::FailedPreconditionError(
"Tokens in the vocabulary must be unique.");
Expand Down
1 change: 1 addition & 0 deletions tensorflow_text/core/kernels/string_vocab.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ namespace text {

StringVocab::StringVocab(const std::vector<std::string>& vocab)
: vocab_(vocab) {
+index_map_.reserve(vocab.size());
for (int i = 0; i < vocab.size(); ++i) {
index_map_[vocab_[i]] = i;
}
Expand Down
2 changes: 2 additions & 0 deletions tensorflow_text/core/kernels/string_vocab.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ namespace text {
class StringVocab : public WordpieceVocab {
public:
explicit StringVocab(const std::vector<std::string>& vocab);
+StringVocab(const StringVocab&) = delete;
+StringVocab& operator=(const StringVocab&) = delete;
LookupStatus Contains(absl::string_view key, bool* value) const override;
absl::optional<int> LookupId(absl::string_view key) const;
// Returns the key of `vocab_id` or empty if `vocab_id` is not valid.
Expand Down