Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion tensorflow_text/core/kernels/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@ load("//tensorflow_text:tftext.bzl", "tf_cc_library", "tflite_cc_library")
licenses(["notice"])

# Visibility rules
package(default_visibility = ["//visibility:public"])
package(
default_applicable_licenses = ["//tensorflow_text:license"],
default_visibility = ["//visibility:public"],
)

exports_files(["LICENSE"])

Expand Down Expand Up @@ -348,6 +351,7 @@ cc_library(
],
deps = [
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/types:span",
],
)

Expand Down
14 changes: 14 additions & 0 deletions tensorflow_text/core/kernels/darts_clone_trie_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,20 @@ TEST(DartsCloneTrieBuildError, NegativeValues) {
StatusIs(util::error::INVALID_ARGUMENT));
}

TEST(DartsCloneTrieTest, OutOfBoundsAccessIsRejected) {
  // Build a genuine trie first, so the only irregular thing about the wrapper
  // below is the length of the view it is given.
  std::vector<std::string> tokens{"def", "\xe1\xb8\x8aZZ", "Abc"};
  ASSERT_OK_AND_ASSIGN(std::vector<uint32_t> units,
                       BuildDartsCloneTrie(tokens));

  // Expose just the first unit of the array to emulate an out-of-bounds
  // access attempt during traversal.
  auto first_unit_only = absl::MakeSpan(units.data(), 1);
  ASSERT_OK_AND_ASSIGN(DartsCloneTrieWrapper wrapper,
                       DartsCloneTrieWrapper::Create(first_unit_only));

  // The step along 'd' from the root is expected to be refused.
  DartsCloneTrieWrapper::TraversalCursor at_root =
      wrapper.CreateTraversalCursorPointToRoot();
  EXPECT_FALSE(wrapper.TryTraverseOneStep(at_root, 'd'));
}

} // namespace trie_utils
} // namespace text
} // namespace tensorflow
45 changes: 36 additions & 9 deletions tensorflow_text/core/kernels/darts_clone_trie_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <string.h>

#include "absl/status/statusor.h"
#include "absl/types/span.h"

namespace tensorflow {
namespace text {
Expand All @@ -51,16 +52,28 @@ class DartsCloneTrieWrapper {
uint32_t unit = 0;
};

// Creates a wrapper over the trie array viewed by the span 'trie_array'.
// Returns an InvalidArgument error if the span is empty or its data pointer
// is null. It is the caller's responsibility to ensure that 'trie_array'
// refers to a valid structure produced by the darts_clone trie builder, and to
// keep that data available for as long as the returned instance is in use.
static absl::StatusOr<DartsCloneTrieWrapper> Create(
    absl::Span<const uint32_t> trie_array) {
  const bool has_units = trie_array.data() != nullptr && !trie_array.empty();
  if (has_units) {
    return DartsCloneTrieWrapper(trie_array);
  }
  return absl::InvalidArgumentError("trie_array is empty or nullptr.");
}

// Legacy compatibility: constructs an instance from a raw pointer so that
// existing callers keep compiling. Returns an InvalidArgument error if
// 'trie_array' is nullptr.
//
// The real length of the array is unknown on this path, so the wrapper is
// given a span of UINT32_MAX elements. The bounds checks in this class compare
// node ids against the span size, so on this path they reject only the id
// UINT32_MAX itself and give no real protection. Callers should strongly
// prefer the absl::Span overload for memory safety.
static absl::StatusOr<DartsCloneTrieWrapper> Create(
    const uint32_t* trie_array) {
  if (trie_array == nullptr) {
    return absl::InvalidArgumentError("trie_array is nullptr.");
  }
  // Always go through the span-taking constructor; returning before this point
  // would leave the span construction unreachable.
  const size_t max_len = static_cast<size_t>(UINT32_MAX);
  return DartsCloneTrieWrapper(absl::MakeSpan(trie_array, max_len));
}

// Creates a cursor pointing to the root.
Expand All @@ -70,20 +83,28 @@ class DartsCloneTrieWrapper {

// Returns a cursor positioned at 'node_id', carrying the unit stored at that
// index of the trie array. If 'node_id' is not a valid index into the array, a
// zeroed cursor (node_id 0, unit 0) is returned instead.
TraversalCursor CreateTraversalCursor(uint32_t node_id) {
  TraversalCursor cursor{0, 0};
  if (node_id < trie_array_.size()) {
    cursor.node_id = node_id;
    cursor.unit = trie_array_[node_id];
  }
  return cursor;
}

// Sets the cursor to point to 'node_id', loading the unit stored at that index
// of the trie array. If 'node_id' is not a valid index into the array, 'cursor'
// is left unchanged.
void SetTraversalCursor(TraversalCursor& cursor, uint32_t node_id) {
  // Validate before touching either the cursor or the array: an unconditional
  // read of trie_array_[node_id] would be out of bounds for a bad id.
  if (node_id >= trie_array_.size()) {
    return;
  }
  cursor.node_id = node_id;
  cursor.unit = trie_array_[node_id];
}

// Traverses one step from 'cursor' following 'ch'. If successful (i.e., there
// exists such an edge), moves 'cursor' to the new node and returns true.
// Otherwise, does nothing (i.e., 'cursor' is not changed) and returns false.
bool TryTraverseOneStep(TraversalCursor& cursor, unsigned char ch) const {
const uint32_t next_node_id = cursor.node_id ^ offset(cursor.unit) ^ ch;
if (next_node_id >= trie_array_.size()) {
return false;
}
const uint32_t next_node_unit = trie_array_[next_node_id];
if (label(next_node_unit) != ch) {
return false;
Expand All @@ -108,15 +129,18 @@ class DartsCloneTrieWrapper {
if (!has_leaf(cursor.unit)) {
return false;
}
const uint32_t value_unit =
trie_array_[cursor.node_id ^ offset(cursor.unit)];
const uint32_t value_node_id = cursor.node_id ^ offset(cursor.unit);
if (value_node_id >= trie_array_.size()) {
return false;
}
const uint32_t value_unit = trie_array_[value_node_id];
out_data = value(value_unit);
return true;
}

private:
// Use Create() instead of the constructor.
explicit DartsCloneTrieWrapper(const uint32_t* trie_array)
explicit DartsCloneTrieWrapper(absl::Span<const uint32_t> trie_array)
: trie_array_(trie_array) {}

// The actual implementation of TryTraverseSeveralSteps.
Expand All @@ -127,6 +151,9 @@ class DartsCloneTrieWrapper {
for (; size > 0; --size, ++ptr) {
const unsigned char ch = static_cast<const unsigned char>(*ptr);
cur_id ^= offset(cur_unit) ^ ch;
if (cur_id >= trie_array_.size()) {
return false;
}
cur_unit = trie_array_[cur_id];
if (label(cur_unit) != ch) {
return false;
Expand Down Expand Up @@ -157,8 +184,8 @@ class DartsCloneTrieWrapper {
return static_cast<int>(unit & 0x7fffffff);
}

// The darts trie array, held as a span (rather than a bare pointer) so that
// node ids can be checked against its size before the array is indexed.
absl::Span<const uint32_t> trie_array_;
};

} // namespace trie_utils
Expand Down
Loading