diff --git a/tensorflow_text/core/kernels/BUILD b/tensorflow_text/core/kernels/BUILD index e1dd0ec82..6b4bbda28 100644 --- a/tensorflow_text/core/kernels/BUILD +++ b/tensorflow_text/core/kernels/BUILD @@ -29,6 +29,16 @@ cc_library( ], ) +cc_library( + name = "row_splits_validator", + hdrs = ["row_splits_validator.h"], + compatible_with = ["//buildenv/target:prod"], + deps = [ + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:span", + ], +) + cc_test( name = "boise_offset_converter_test", size = "small", @@ -62,6 +72,7 @@ tf_cc_library( ], deps = [ ":boise_offset_converter", + ":row_splits_validator", "@com_google_absl//absl/status", # lite/kernels/shim:op_kernel tensorflow dep, # lite/kernels/shim:shape tensorflow dep, @@ -112,6 +123,7 @@ tf_cc_library( ], deps = [ ":byte_splitter", + ":row_splits_validator", "@com_google_absl//absl/status", # lite/kernels/shim:op_kernel tensorflow dep, # lite/kernels/shim:shape tensorflow dep, @@ -502,6 +514,7 @@ cc_library( hdrs = ["fast_wordpiece_tokenizer_kernel_template.h"], deps = [ ":fast_wordpiece_tokenizer", + ":row_splits_validator", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", # lite/kernels/shim:op_kernel tensorflow dep, @@ -615,6 +628,7 @@ tf_cc_library( # tf/platform:tstring tensorflow dep, ], deps = [ + ":row_splits_validator", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -801,6 +815,7 @@ tf_cc_library( ], deps = [ ":round_robin_trimmer", + ":row_splits_validator", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", @@ -1020,6 +1035,7 @@ tf_cc_library( # tf:lib tensorflow dep, ], deps = [ + ":row_splits_validator", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", @@ -1288,6 +1304,7 @@ cc_library( hdrs = ["phrase_tokenizer_kernel_template.h"], deps = [ ":phrase_tokenizer", + ":row_splits_validator", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", # lite/kernels/shim:op_kernel tensorflow dep, diff --git a/tensorflow_text/core/kernels/boise_offset_converter_kernel_template.h b/tensorflow_text/core/kernels/boise_offset_converter_kernel_template.h index f49f059aa..7f504219c 100644 --- a/tensorflow_text/core/kernels/boise_offset_converter_kernel_template.h +++ b/tensorflow_text/core/kernels/boise_offset_converter_kernel_template.h @@ -16,16 +16,17 @@ #define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_BOISE_OFFSET_CONVERTER_KERNEL_TEMPLATE_H_ #include -#include -#include +#include #include #include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "tensorflow/core/platform/tstring.h" #include "tensorflow/lite/kernels/shim/op_kernel.h" #include "tensorflow/lite/kernels/shim/shape.h" #include "tensorflow/lite/kernels/shim/status_macros.h" #include "tensorflow_text/core/kernels/boise_offset_converter.h" +#include "tensorflow_text/core/kernels/row_splits_validator.h" namespace tensorflow { namespace text { @@ -304,6 +305,16 @@ absl::Status OffsetsToBoiseTagsOp::Invoke(InvokeContext* context) { } } + SH_RETURN_IF_ERROR(ValidateRowSplits( + absl::MakeConstSpan(input_token_begin_row_splits_vec.Ptr(), + input_token_begin_row_splits_vec.Dim(0)), + input_token_begin_offsets_vec.Dim(0))); + + SH_RETURN_IF_ERROR(ValidateRowSplits( + absl::MakeConstSpan(input_span_begin_row_splits_vec.Ptr(), + input_span_begin_row_splits_vec.Dim(0)), + input_span_begin_offsets_vec.Dim(0))); + // Outputs std::vector boise_tags; std::vector input_token_begin_offsets_vec_i; @@ -562,6 +573,11 @@ absl::Status BoiseTagsToOffsetsOp::Invoke(InvokeContext* context) { } } + SH_RETURN_IF_ERROR(ValidateRowSplits( + absl::MakeConstSpan(input_token_begin_row_splits_vec.Ptr(), + input_token_begin_row_splits_vec.Dim(0)), + input_token_begin_offsets_vec.Dim(0))); + // Outputs std::vector span_begin_offsets; std::vector span_end_offsets; diff --git a/tensorflow_text/core/kernels/byte_splitter_kernel_template.h b/tensorflow_text/core/kernels/byte_splitter_kernel_template.h index 77ab2b1ba..016e02a54 100644 --- a/tensorflow_text/core/kernels/byte_splitter_kernel_template.h +++ b/tensorflow_text/core/kernels/byte_splitter_kernel_template.h @@ -15,15 +15,18 @@ #ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_BYTE_SPLITTER_KERNEL_TEMPLATE_H_ #define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_BYTE_SPLITTER_KERNEL_TEMPLATE_H_ -#include +#include +#include #include #include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "tensorflow/core/platform/tstring.h" #include "tensorflow/lite/kernels/shim/op_kernel.h" #include "tensorflow/lite/kernels/shim/shape.h" #include "tensorflow/lite/kernels/shim/status_macros.h" #include "tensorflow_text/core/kernels/byte_splitter.h" +#include "tensorflow_text/core/kernels/row_splits_validator.h" namespace tensorflow { namespace text { @@ -277,6 +280,13 @@ template context->GetInput(kInputRowSplits)); const auto in_splits = in_splits_view->template As(); + if (starts.Dim(0) != ends.Dim(0)) { + return absl::InvalidArgumentError( + "starts and ends must have the same size."); + } + SH_RETURN_IF_ERROR(ValidateRowSplits( + absl::MakeConstSpan(in_splits.Ptr(), in_splits.Dim(0)), starts.Dim(0))); + ByteSplitter splitter; // Outputs diff --git a/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_kernel_template.h b/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_kernel_template.h index efc26197a..ea1142983 100644 --- a/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_kernel_template.h +++ b/tensorflow_text/core/kernels/fast_wordpiece_tokenizer_kernel_template.h @@ -15,11 +15,19 @@ #ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_WORDPIECE_TOKENIZER_KERNEL_TEMPLATE_H_ #define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_FAST_WORDPIECE_TOKENIZER_KERNEL_TEMPLATE_H_ +#include +#include +#include + #include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "tensorflow/core/platform/tstring.h" #include "tensorflow/lite/kernels/shim/op_kernel.h" +#include "tensorflow/lite/kernels/shim/shape.h" #include "tensorflow/lite/kernels/shim/status_macros.h" #include "tensorflow_text/core/kernels/fast_wordpiece_tokenizer.h" +#include "tensorflow_text/core/kernels/row_splits_validator.h" namespace tensorflow { namespace text { @@ -147,7 +155,7 @@ absl::Status FastWordpieceTokenizeWithOffsetsOp::Invoke( // Create() is very cheap. auto fast_wordpiece_tokenizer = ::tensorflow::text::FastWordpieceTokenizer::Create( - wp_model->template Data().data()); + wp_model->template Data().data()); SH_RETURN_IF_ERROR(fast_wordpiece_tokenizer.status()); // TODO(xysong): Optimize based on which information below is requested. @@ -180,13 +188,13 @@ absl::Status FastWordpieceTokenizeWithOffsetsOp::Invoke( SH_RETURN_IF_ERROR(this->template FillOutputTensor( subwords, kOutputSubwords, context)); - SH_RETURN_IF_ERROR(this->template FillOutputTensor( + SH_RETURN_IF_ERROR(this->template FillOutputTensor( subword_ids, kOutputIds, context)); - SH_RETURN_IF_ERROR(this->template FillOutputTensor( + SH_RETURN_IF_ERROR(this->template FillOutputTensor( row_splits, kOutputRowSplits, context)); - SH_RETURN_IF_ERROR(this->template FillOutputTensor( + SH_RETURN_IF_ERROR(this->template FillOutputTensor( begin_offset, kStartValues, context)); - SH_RETURN_IF_ERROR(this->template FillOutputTensor( + SH_RETURN_IF_ERROR(this->template FillOutputTensor( end_offset, kEndValues, context)); return absl::OkStatus(); @@ -311,7 +319,11 @@ absl::Status FastWordpieceDetokenizeOp::Invoke(InvokeContext* context) { SH_ASSIGN_OR_RETURN(const auto input_row_splits, context->GetInput(kInputRowSplits)); - const auto& row_splits_vec = input_row_splits->template As(); + const auto& row_splits_vec = input_row_splits->template As(); + + SH_RETURN_IF_ERROR(ValidateRowSplits( + absl::MakeConstSpan(row_splits_vec.Ptr(), row_splits_vec.Dim(0)), + values_vec.Dim(0))); SH_ASSIGN_OR_RETURN(const auto wp_model, context->GetInput(kWpModel)); // OK to create on every call because FastWordpieceTokenizer is a @@ -319,7 +331,7 @@ absl::Status FastWordpieceDetokenizeOp::Invoke(InvokeContext* context) { // Create() is very cheap. auto fast_wordpiece_tokenizer = ::tensorflow::text::FastWordpieceTokenizer::Create( - wp_model->template Data().data()); + wp_model->template Data().data()); SH_RETURN_IF_ERROR(fast_wordpiece_tokenizer.status()); std::vector sentences; diff --git a/tensorflow_text/core/kernels/ngrams_kernel_template.h b/tensorflow_text/core/kernels/ngrams_kernel_template.h index 1a8a3fc8f..1a76bd735 100644 --- a/tensorflow_text/core/kernels/ngrams_kernel_template.h +++ b/tensorflow_text/core/kernels/ngrams_kernel_template.h @@ -29,14 +29,23 @@ limitations under the License. #ifndef TENSORFLOW_TEXT_CORE_KERNELS_NGRAMS_KERNEL_TEMPLATE_H_ #define TENSORFLOW_TEXT_CORE_KERNELS_NGRAMS_KERNEL_TEMPLATE_H_ +#include +#include +#include +#include +#include + #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "tensorflow/core/platform/tstring.h" #include "tensorflow/lite/kernels/shim/op_kernel.h" +#include "tensorflow/lite/kernels/shim/shape.h" #include "tensorflow/lite/kernels/shim/status_macros.h" #include "tensorflow/lite/kernels/shim/tensor_view.h" +#include "tensorflow_text/core/kernels/row_splits_validator.h" namespace tensorflow { namespace text { @@ -191,6 +200,8 @@ class NgramsStringJoin : public tflite::shim::OpKernelShimShape()))); const auto input_buffer = input_tensor_row_splits->template Data(); + SH_RETURN_IF_ERROR(ValidateRowSplits( + absl::MakeConstSpan(input_buffer.data(), input_buffer.size()))); const auto output_buffer = output_tensor_row_splits->template Data(); std::memcpy(output_buffer.data(), input_buffer.data(), @@ -214,6 +225,12 @@ class NgramsStringJoin : public tflite::shim::OpKernelShimtemplate Data(); + if (ctx->NumOutputs() != 1) { + SH_RETURN_IF_ERROR(ValidateRowSplits( + absl::MakeConstSpan(input_row_splits, n_row_splits), + input_values_data.size())); + } + // Create ngrams by looping through the innermost input splits. std::vector buffer; for (int i = 0; i < n_row_splits - 1; ++i) { @@ -247,8 +264,8 @@ class NgramsStringJoin : public tflite::shim::OpKernelShim +#include +#include + #include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "tensorflow/core/platform/tstring.h" #include "tensorflow/lite/kernels/shim/op_kernel.h" +#include "tensorflow/lite/kernels/shim/shape.h" #include "tensorflow/lite/kernels/shim/status_macros.h" #include "tensorflow_text/core/kernels/phrase_tokenizer.h" +#include "tensorflow_text/core/kernels/row_splits_validator.h" namespace tensorflow { namespace text { @@ -126,7 +134,7 @@ absl::Status PhraseTokenizeOp::Invoke(InvokeContext* context) { // lightweight, memory-mapped wrapper on `phrase_model` tensor, and thus // Create() is very cheap. auto phrase_tokenizer = ::tensorflow::text::PhraseTokenizer::Create( - phrase_model->template Data().data()); + phrase_model->template Data().data()); SH_RETURN_IF_ERROR(phrase_tokenizer.status()); std::vector subwords; @@ -159,13 +167,13 @@ absl::Status PhraseTokenizeOp::Invoke(InvokeContext* context) { kOutputIds, Shape({static_cast( subword_ids.size())}))); /* same shape as `output_subwords` */ - auto output_ids_vec = output_ids->template As(); + auto output_ids_vec = output_ids->template As(); SH_ASSIGN_OR_RETURN( auto output_row_splits, context->GetOutput(kOutputRowSplits, Shape({static_cast(row_splits.size())}))); - auto output_row_splits_vec = output_row_splits->template As(); + auto output_row_splits_vec = output_row_splits->template As(); for (int i = 0; i < subwords.size(); ++i) { output_subwords_vec(i) = subwords[i]; @@ -299,14 +307,18 @@ absl::Status PhraseDetokenizeOp::Invoke(InvokeContext* context) { SH_ASSIGN_OR_RETURN(const auto input_row_splits, context->GetInput(kInputRowSplits)); - const auto& row_splits_vec = input_row_splits->template As(); + const auto& row_splits_vec = input_row_splits->template As(); + + SH_RETURN_IF_ERROR(ValidateRowSplits( + absl::MakeConstSpan(row_splits_vec.Ptr(), row_splits_vec.Dim(0)), + values_vec.Dim(0))); SH_ASSIGN_OR_RETURN(const auto phrase_model, context->GetInput(kPhraseModel)); // OK to create on every call because PhraseTokenizer is a // lightweight, memory-mapped wrapper on `phrase_model` tensor, and thus // Create() is very cheap. auto phrase_tokenizer = ::tensorflow::text::PhraseTokenizer::Create( - phrase_model->template Data().data()); + phrase_model->template Data().data()); SH_RETURN_IF_ERROR(phrase_tokenizer.status()); std::vector sentences; diff --git a/tensorflow_text/core/kernels/round_robin_trimmer.h b/tensorflow_text/core/kernels/round_robin_trimmer.h index 5273dfa9e..fba4b3293 100644 --- a/tensorflow_text/core/kernels/round_robin_trimmer.h +++ b/tensorflow_text/core/kernels/round_robin_trimmer.h @@ -16,11 +16,12 @@ #define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_ROUND_ROBIN_TRIMMER_H_ #include +#include #include #include #include -#include "tensorflow_text/core/kernels/trimmer.h" +#include "tensorflow_text/core/kernels/trimmer.h" namespace tensorflow { namespace text { @@ -153,7 +154,7 @@ std::vector RoundRobinTrimmer::GenerateMasksInternal( std::vector masks(end - begin); auto m = masks.begin(); for (auto it = begin; it != end; ++it, ++m) { - m->reserve(it->back()); + m->reserve(std::max(static_cast(0), it->empty() ? 0 : it->back())); } // Process all batches, updating the masks a batch at a time. ProcessSplitsByBatch(begin, end, [&masks](std::vector* rows) { @@ -305,7 +306,8 @@ void RoundRobinTrimmer::ProcessSplitsByBatch( int idx = 0; for (auto i = begin; i < end; ++i, ++idx) { value_row_sizes[idx].idx = idx; - value_row_sizes[idx].size = (*i)[batch_idx + 1] - (*i)[batch_idx]; + Tsplits row_size = (*i)[batch_idx + 1] - (*i)[batch_idx]; + value_row_sizes[idx].size = row_size < 0 ? 0 : row_size; } // Perform the main processing of the batch ProcessBatch(&value_row_sizes, callback); diff --git a/tensorflow_text/core/kernels/round_robin_trimmer_kernel_template.h b/tensorflow_text/core/kernels/round_robin_trimmer_kernel_template.h index 51f17da43..01cbfb95e 100644 --- a/tensorflow_text/core/kernels/round_robin_trimmer_kernel_template.h +++ b/tensorflow_text/core/kernels/round_robin_trimmer_kernel_template.h @@ -16,15 +16,17 @@ #define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_ROUND_ROBIN_TRIMMER_KERNEL_TEMPLATE_H_ #include -#include +#include #include #include "absl/status/status.h" +#include "absl/strings/str_cat.h" #include "absl/types/span.h" #include "tensorflow/lite/kernels/shim/op_kernel.h" #include "tensorflow/lite/kernels/shim/shape.h" #include "tensorflow/lite/kernels/shim/status_macros.h" #include "tensorflow_text/core/kernels/round_robin_trimmer.h" +#include "tensorflow_text/core/kernels/row_splits_validator.h" namespace tensorflow { namespace text { @@ -147,7 +149,7 @@ template absl::Status RoundRobinTrimOp::Invoke(InvokeContext* context) { // Inputs SH_ASSIGN_OR_RETURN(const auto msl, context->GetInput(kMaxSeqLength)); - const int max_sequence_length = msl->template AsScalar(); + const int max_sequence_length = msl->template AsScalar(); std::vector> list_of_values(number_of_segments_); std::vector> list_of_splits(number_of_segments_); @@ -158,6 +160,9 @@ absl::Status RoundRobinTrimOp::Invoke(InvokeContext* context) { int row_split_idx = kInputRowSplits + number_of_segments_ - 1 + i; SH_ASSIGN_OR_RETURN(const auto rs, context->GetInput(row_split_idx)); list_of_splits[i] = rs->template Data(); + + SH_RETURN_IF_ERROR(ValidateRowSplits(list_of_splits[i], + list_of_values[i].size())); } // Compute @@ -295,13 +300,17 @@ absl::Status RoundRobinGenerateMasksOp::Invoke( InvokeContext* context) { // Inputs SH_ASSIGN_OR_RETURN(const auto msl, context->GetInput(kMaxSeqLength)); - const int max_sequence_length = msl->template AsScalar(); + const int max_sequence_length = msl->template AsScalar(); std::vector> list_of_splits(number_of_segments_); for (int i = 0; i < number_of_segments_; ++i) { int row_split_idx = kInputRowSplits + number_of_segments_ - 1 + i; SH_ASSIGN_OR_RETURN(const auto rs, context->GetInput(row_split_idx)); list_of_splits[i] = rs->template Data(); + + SH_ASSIGN_OR_RETURN(const auto fv, context->GetInput(kInputValues + i)); + SH_RETURN_IF_ERROR(ValidateRowSplits( + list_of_splits[i], fv->template Data().size())); } // Compute diff --git a/tensorflow_text/core/kernels/round_robin_trimmer_test.cc b/tensorflow_text/core/kernels/round_robin_trimmer_test.cc index 50c21e32d..513d4f4bb 100644 --- a/tensorflow_text/core/kernels/round_robin_trimmer_test.cc +++ b/tensorflow_text/core/kernels/round_robin_trimmer_test.cc @@ -14,9 +14,10 @@ #include "tensorflow_text/core/kernels/round_robin_trimmer.h" +#include +#include #include #include -#include #include #include @@ -225,6 +226,14 @@ INSTANTIATE_TEST_SUITE_P(RoundRobinTrimmerTestSuite, RoundRobinTrimmerTest, testing::ValuesIn(params)); +TEST(RoundRobinTrimmerSingleTest, NonMonotonicRowSplits) { + RoundRobinTrimmer t(10); + std::vector> input_vals = {{1, 2, 3, 4, 5}}; + std::vector> input_splits = {{0, 5, 2}}; + auto [vals, splits] = t.TrimBatch(input_vals, input_splits); + EXPECT_THAT(splits[0], ::testing::ElementsAreArray({0, 5, 5})); +} + } // namespace } // namespace text } // namespace tensorflow diff --git a/tensorflow_text/core/kernels/row_splits_validator.h b/tensorflow_text/core/kernels/row_splits_validator.h new file mode 100644 index 000000000..b9a6a276a --- /dev/null +++ b/tensorflow_text/core/kernels/row_splits_validator.h @@ -0,0 +1,52 @@ +// Copyright 2026 TF.Text Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_ROW_SPLITS_VALIDATOR_H_ +#define THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_ROW_SPLITS_VALIDATOR_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/types/span.h" + +namespace tensorflow { +namespace text { + +template +inline absl::Status ValidateRowSplits(absl::Span row_splits, + int64_t max_values_size = -1) { + if (row_splits.empty()) { + return absl::InvalidArgumentError("row_splits cannot be empty."); + } + if (row_splits[0] != 0) { + return absl::InvalidArgumentError("row_splits must start with 0."); + } + for (size_t i = 0; i < row_splits.size() - 1; ++i) { + if (row_splits[i + 1] < row_splits[i]) { + return absl::InvalidArgumentError( + "row_splits must be monotonically increasing."); + } + } + if (max_values_size >= 0 && row_splits.back() > max_values_size) { + return absl::InvalidArgumentError( + "row_splits values exceed the size of the values array."); + } + return absl::OkStatus(); +} + +} // namespace text +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_TEXT_CORE_KERNELS_ROW_SPLITS_VALIDATOR_H_ diff --git a/tensorflow_text/core/kernels/split_merge_tokenize_kernel.cc b/tensorflow_text/core/kernels/split_merge_tokenize_kernel.cc index 5491fab4d..487cdb219 100644 --- a/tensorflow_text/core/kernels/split_merge_tokenize_kernel.cc +++ b/tensorflow_text/core/kernels/split_merge_tokenize_kernel.cc @@ -12,20 +12,27 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include -#include +#include #include #include #include "absl/status/status.h" #include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "icu4c/source/common/unicode/uchar.h" #include "icu4c/source/common/unicode/umachine.h" #include "icu4c/source/common/unicode/utf8.h" +#include "tensorflow/compiler/xla/tsl/platform/macros.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/op_requires.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/tstring.h" +#include "tensorflow_text/core/kernels/row_splits_validator.h" namespace tensorflow { namespace text { @@ -89,7 +96,7 @@ Status TokenizeByLabel(const absl::string_view& text, bool last_character_is_break_character = false; int start = 0; bool has_new_token_generated_for_text = false; - const auto& labels = labels_tensor.unaligned_flat(); + const auto& labels = labels_tensor.unaligned_flat(); for (int i = 0; i < chars.size(); ++i) { const bool is_break_character = IsBreakChar(chars[i]); if (!is_break_character) { @@ -138,14 +145,18 @@ class SplitMergeTokenizeWithOffsetsOp : public OpKernel { " elements, got ", row_splits->dim_size(0))); - std::vector tokens; + std::vector tokens; std::vector begin_offset; std::vector end_offset; std::vector output_row_splits(1, 0); // Iterate through all the values and tokenize them. - const auto& values_vec = input_values->flat(); - const auto& row_splits_vec = row_splits->flat(); + const auto& values_vec = input_values->flat(); + const auto& row_splits_vec = row_splits->flat(); + OP_REQUIRES_OK(ctx, ValidateRowSplits( + absl::MakeConstSpan(row_splits_vec.data(), + row_splits_vec.size()), + labels->dim_size(0))); for (int i = 0; i < values_vec.size(); ++i) { // Tokenize into tokens and record the offset locations. int num_tokens = 0; @@ -160,10 +171,10 @@ class SplitMergeTokenizeWithOffsetsOp : public OpKernel { output_row_splits.push_back(num_tokens + output_row_splits.back()); } - std::vector output_tokens_shape; + std::vector output_tokens_shape; output_tokens_shape.push_back(tokens.size()); - std::vector output_row_splits_shape; + std::vector output_row_splits_shape; output_row_splits_shape.push_back(output_row_splits.size()); Tensor* output_values; @@ -177,19 +188,19 @@ class SplitMergeTokenizeWithOffsetsOp : public OpKernel { ctx->allocate_output("output_row_splits", TensorShape(output_row_splits_shape), &output_row_splits_tensor)); - auto output_row_splits_vec = output_row_splits_tensor->vec(); + auto output_row_splits_vec = output_row_splits_tensor->vec(); Tensor* start_values; OP_REQUIRES_OK(ctx, ctx->allocate_output("start_values", TensorShape(output_tokens_shape), &start_values)); - auto start_values_vec = start_values->vec(); + auto start_values_vec = start_values->vec(); Tensor* limit_values; OP_REQUIRES_OK(ctx, ctx->allocate_output("limit_values", TensorShape(output_tokens_shape), &limit_values)); - auto limit_values_vec = limit_values->vec(); + auto limit_values_vec = limit_values->vec(); for (int i = 0; i < tokens.size(); ++i) { output_values_vec(i) = tokens[i];