From 51a83634c28365c3d52d886a7e33a2b061d87e04 Mon Sep 17 00:00:00 2001 From: "TF.Text Team" Date: Fri, 1 May 2026 09:57:09 -0700 Subject: [PATCH] Mitigates a critical Out-of-Bounds (OOB) read vulnerability by transitioning internal trie references to absl::Span and validating array lengths before dynamic offset queries. Added explicit boundary safety unit tests to verify mitigation. PiperOrigin-RevId: 908780954 --- tensorflow_text/core/kernels/BUILD | 8 ++- .../core/kernels/darts_clone_trie_test.cc | 14 ++++++ .../core/kernels/darts_clone_trie_wrapper.h | 50 +++++++++++++++---- 3 files changed, 62 insertions(+), 10 deletions(-) diff --git a/tensorflow_text/core/kernels/BUILD b/tensorflow_text/core/kernels/BUILD index 8aba03626..743530abd 100644 --- a/tensorflow_text/core/kernels/BUILD +++ b/tensorflow_text/core/kernels/BUILD @@ -11,7 +11,10 @@ load("//tensorflow_text:tftext.bzl", "tf_cc_library", "tflite_cc_library") licenses(["notice"]) # Visibility rules -package(default_visibility = ["//visibility:public"]) +package( + default_applicable_licenses = ["//tensorflow_text:license"], + default_visibility = ["//visibility:public"], +) exports_files(["LICENSE"]) @@ -347,7 +350,10 @@ cc_library( "darts_clone_trie_wrapper.h", ], deps = [ + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow_text/core/kernels/darts_clone_trie_test.cc b/tensorflow_text/core/kernels/darts_clone_trie_test.cc index a80c28353..e74735b29 100644 --- a/tensorflow_text/core/kernels/darts_clone_trie_test.cc +++ b/tensorflow_text/core/kernels/darts_clone_trie_test.cc @@ -183,6 +183,20 @@ TEST(DartsCloneTrieBuildError, NegativeValues) { StatusIs(util::error::INVALID_ARGUMENT)); } +TEST(DartsCloneTrieTest, OutOfBoundsAccessIsRejected) { + std::vector vocab_tokens{"def", "\xe1\xb8\x8aZZ", "Abc"}; + ASSERT_OK_AND_ASSIGN(std::vector trie_array, + 
BuildDartsCloneTrie(vocab_tokens)); + // Wrap using a constrained span to emulate an out-of-bounds access attempt. + auto span = absl::MakeSpan(trie_array.data(), 1); + ASSERT_OK_AND_ASSIGN(DartsCloneTrieWrapper trie, + DartsCloneTrieWrapper::Create(span)); + + DartsCloneTrieWrapper::TraversalCursor cursor = + trie.CreateTraversalCursorPointToRoot(); + EXPECT_FALSE(trie.TryTraverseOneStep(cursor, 'd')); +} + } // namespace trie_utils } // namespace text } // namespace tensorflow diff --git a/tensorflow_text/core/kernels/darts_clone_trie_wrapper.h b/tensorflow_text/core/kernels/darts_clone_trie_wrapper.h index 43067ec1b..ec0826e22 100644 --- a/tensorflow_text/core/kernels/darts_clone_trie_wrapper.h +++ b/tensorflow_text/core/kernels/darts_clone_trie_wrapper.h @@ -30,7 +30,10 @@ #include #include +#include "absl/base/attributes.h" +#include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/types/span.h" namespace tensorflow { namespace text { @@ -51,16 +54,31 @@ class DartsCloneTrieWrapper { uint32_t unit = 0; }; - // Constructs an instance by passing in the pointer to the trie array data. + // Constructs an instance by passing in the span of the trie array data. // The caller needs to make sure that 'trie_array' points to a valid structure // returned by darts_clone trie builder. The caller also needs to maintain the // availability of 'trie_array' throughout the lifetime of this instance. + static absl::StatusOr Create( + absl::Span trie_array) { + if (trie_array.empty() || trie_array.data() == nullptr) { + return absl::InvalidArgumentError("trie_array is empty or nullptr."); + } + return DartsCloneTrieWrapper(trie_array); + } + + // Legacy compatibility: constructs an instance from a raw pointer, assuming + // maximum uint32_t bound to ensure compilation of existing callers, but + // callers should strongly prefer using absl::Span for memory safety. 
+ // This method remains UNPROTECTED against out-of-bounds accesses because the + // size of the trie_array is unknown. + ABSL_DEPRECATED("Use Create(absl::Span) instead.") static absl::StatusOr Create( const uint32_t* trie_array) { if (trie_array == nullptr) { return absl::InvalidArgumentError("trie_array is nullptr."); } - return DartsCloneTrieWrapper(trie_array); + size_t max_len = static_cast(UINT32_MAX); + return DartsCloneTrieWrapper(absl::MakeSpan(trie_array, max_len)); } // Creates a cursor pointing to the root. @@ -70,13 +88,18 @@ class DartsCloneTrieWrapper { // Creates a cursor pointing to the 'node_id'. TraversalCursor CreateTraversalCursor(uint32_t node_id) { + if (node_id >= trie_array_.size()) { + return {0, 0}; + } return {node_id, trie_array_[node_id]}; } // Sets the cursor to point to 'node_id'. void SetTraversalCursor(TraversalCursor& cursor, uint32_t node_id) { - cursor.node_id = node_id; - cursor.unit = trie_array_[node_id]; + if (node_id < trie_array_.size()) { + cursor.node_id = node_id; + cursor.unit = trie_array_[node_id]; + } } // Traverses one step from 'cursor' following 'ch'. If successful (i.e., there @@ -84,6 +107,9 @@ class DartsCloneTrieWrapper { // Otherwise, does nothing (i.e., 'cursor' is not changed) and returns false. 
bool TryTraverseOneStep(TraversalCursor& cursor, unsigned char ch) const { const uint32_t next_node_id = cursor.node_id ^ offset(cursor.unit) ^ ch; + if (next_node_id >= trie_array_.size()) { + return false; + } const uint32_t next_node_unit = trie_array_[next_node_id]; if (label(next_node_unit) != ch) { return false; @@ -108,15 +134,18 @@ class DartsCloneTrieWrapper { if (!has_leaf(cursor.unit)) { return false; } - const uint32_t value_unit = - trie_array_[cursor.node_id ^ offset(cursor.unit)]; + const uint32_t value_node_id = cursor.node_id ^ offset(cursor.unit); + if (value_node_id >= trie_array_.size()) { + return false; + } + const uint32_t value_unit = trie_array_[value_node_id]; out_data = value(value_unit); return true; } private: // Use Create() instead of the constructor. - explicit DartsCloneTrieWrapper(const uint32_t* trie_array) + explicit DartsCloneTrieWrapper(absl::Span trie_array) : trie_array_(trie_array) {} // The actual implementation of TryTraverseSeveralSteps. @@ -127,6 +156,9 @@ class DartsCloneTrieWrapper { for (; size > 0; --size, ++ptr) { const unsigned char ch = static_cast(*ptr); cur_id ^= offset(cur_unit) ^ ch; + if (cur_id >= trie_array_.size()) { + return false; + } cur_unit = trie_array_[cur_id]; if (label(cur_unit) != ch) { return false; @@ -157,8 +189,8 @@ class DartsCloneTrieWrapper { return static_cast(unit & 0x7fffffff); } - // The pointer to the darts trie array. - const uint32_t* trie_array_; + // The darts trie array, held as a span so accesses can be bounds-checked. + absl::Span trie_array_; }; } // namespace trie_utils