diff --git a/include/cuco/detail/open_addressing/robin_hood/open_addressing_ref_impl.cuh b/include/cuco/detail/open_addressing/robin_hood/open_addressing_ref_impl.cuh new file mode 100644 index 000000000..5bbbffc56 --- /dev/null +++ b/include/cuco/detail/open_addressing/robin_hood/open_addressing_ref_impl.cuh @@ -0,0 +1,2378 @@ +/* + * Copyright (c) 2023-2026, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#if defined(CUCO_HAS_CUDA_BARRIER) +#include +#endif + +#include + +namespace cuco { +namespace detail { +namespace robin_hood { + +/// Three-way insert result enum +enum class insert_result : cuda::std::int8_t { CONTINUE = 0, SUCCESS = 1, DUPLICATE = 2 }; + +/** + * @brief Helper struct to store intermediate bucket probing results. + */ +struct bucket_probing_results { + detail::equal_result state_; ///< Equal result + cuda::std::int32_t intra_bucket_index_; ///< Intra-bucket index + + /** + * @brief Constructs bucket_probing_results. + * + * @param state The three way equality result + * @param index Intra-bucket index + */ + __device__ explicit constexpr bucket_probing_results(detail::equal_result state, + cuda::std::int32_t index) noexcept + : state_{state}, intra_bucket_index_{index} + { + } +}; + +/** + * @brief Robin Hood inverse primitive for the linear probing sequence. + * + * @note Recovers a resident's probe distance ("age") from the slot it occupies: how many probing + * steps the resident sits from its own home bucket. For the linear sequence this is a single + * subtract — `(slot_base - resident_home) / stride mod num_buckets`. This is the linear-only + * overload; a `double_hashing` variant would add its own overload here (a modular inverse of the + * resident's per-key step, or a stored age), which is the single place that change lands. + * + * @tparam BucketSize Size of the bucket + * @tparam CGSize Size of CUDA Cooperative Groups + * @tparam Hash Unary callable type + * @tparam ProbeKey Type of probing key + * @tparam Extent Type of extent + * + * @param scheme The underlying linear probing scheme (supplies the hash function) + * @param resident_key The key currently residing in the slot + * @param slot_index The slot index at which `resident_key` resides + * @param upper_bound Upper bound of the iteration + * @return The resident's probe distance, in probing steps + */ +template +[[nodiscard]] __host__ __device__ constexpr typename Extent::value_type probe_distance( + linear_probing const& scheme, + ProbeKey resident_key, + typename Extent::value_type slot_index, + Extent upper_bound) noexcept +{ + using size_type = typename Extent::value_type; + size_type constexpr stride = CGSize * BucketSize; + auto const bound = static_cast(upper_bound); + auto const hash = scheme.hash_function(); + + // Home bucket base of the resident, using the same alignment as `make_iterator`. + size_type const resident_home = + cuco::detail::sanitize_hash(hash(resident_key)) % (bound / stride) * stride; + + // Bucket-strided base of the slot the resident currently occupies. The per-lane `thread_rank` + // offset (which is < stride) is stripped by the floor division so that the distance is measured + // in whole probing steps, consistent with the forward sequence. + size_type const slot_base = (slot_index / stride) * stride; + + // (slot_base - resident_home) mod capacity, expressed in probing steps. + return static_cast((slot_base + bound - resident_home) % bound) / stride; +} + +/** + * @brief Common device non-owning "ref" implementation class. + * + * @note This class should NOT be used directly. + * + * @throw If the size of the given key type is larger than `cuco::open_addressing_max_key_size` + * @throw If the size of the given slot type is larger than `cuco::open_addressing_max_slot_size` + * @throw If the given key type doesn't have unique object representations, i.e., + * `cuco::is_bitwise_comparable_v == false` + * @throw If the given payload type doesn't have unique object representations, i.e., + * `cuco::is_bitwise_comparable_v == false` + * @throw If the probing scheme type is not inherited from `cuco::detail::probing_scheme_base` + * + * @tparam Key Type used for keys. Requires `sizeof(Key) <= cuco::open_addressing_max_key_size` and + * `cuco::is_bitwise_comparable_v` + * @tparam Scope The scope in which operations will be performed by individual threads. + * @tparam KeyEqual Binary callable type used to compare two keys for equality + * @tparam ProbingScheme Probing scheme (see `include/cuco/probing_scheme.cuh` for options) + * @tparam StorageRef Storage ref type. Its `value_type` must fit in + * `cuco::open_addressing_max_slot_size`; + * payloads, if present, must be 4 or 8 bytes (or 16 with sm_90+) and satisfy + * `cuco::is_bitwise_comparable_v` + * @tparam AllowsDuplicates Flag indicating whether duplicate keys are allowed or not + */ +template +class open_addressing_ref_impl + : private open_addressing_compatible { + using storage_value_type = typename StorageRef::value_type; + + /// Determines if the container is a key/value or key-only store + static constexpr auto has_payload = not cuda::std::is_same_v; + + /// Flag indicating whether duplicate keys are allowed or not + static constexpr auto allows_duplicates = AllowsDuplicates; + + // TODO: how to re-enable this check? + // static_assert(is_bucket_extent_v, + // "Extent is not a valid cuco::bucket_extent"); + + public: + using key_type = Key; ///< Key type + using probing_scheme_type = ProbingScheme; ///< Type of probing scheme + using hasher = typename probing_scheme_type::hasher; ///< Hash function type + using storage_ref_type = StorageRef; ///< Type of storage ref + using bucket_type = typename storage_ref_type::bucket_type; ///< Bucket type + using value_type = typename storage_ref_type::value_type; ///< Storage element type + using extent_type = typename storage_ref_type::extent_type; ///< Extent type + using size_type = typename storage_ref_type::size_type; ///< Probing scheme size type + using key_equal = KeyEqual; ///< Type of key equality binary callable + using iterator = typename storage_ref_type::iterator; ///< Slot iterator type + using const_iterator = typename storage_ref_type::const_iterator; ///< Const slot iterator type + + static constexpr auto cg_size = probing_scheme_type::cg_size; ///< Cooperative group size + static constexpr auto bucket_size = + storage_ref_type::bucket_size; ///< Number of elements handled per bucket + static constexpr auto thread_scope = Scope; ///< CUDA thread scope + + // Robin Hood displacement swaps the in-flight pair into an *occupied* slot, which needs a single + // atomic CAS of the whole slot. That requires a packable slot: <= 8 bytes (atom.cas.b64), or + // padding-free and <= 16 bytes on an sm_90+ build (atom.cas.b128). A non-packable slot (e.g. a + // padded `pair`) would fall back to a split key/value CAS, which cannot move an + // occupied slot -- displacement would livelock. Reject it at compile time rather than hang. + static constexpr bool robin_hood_slot_is_single_cas = sizeof(value_type) <= 8 +#if defined(CUCO_HAS_128BIT_ATOMICS) + or cuco::detail::is_packable() +#endif + ; + static_assert(robin_hood_slot_is_single_cas, + "Robin Hood probing requires a single-CAS slot: the key+value must fit in 8 bytes, " + "or be packable (padding-free) and <= 16 bytes on an sm_90+ build. A padded slot " + "(e.g. pair) is unsupported -- displacement would livelock."); + + /** + * @brief Constructs open_addressing_ref_impl. + * + * @param empty_slot_sentinel Sentinel indicating an empty slot + * @param predicate Key equality binary callable + * @param probing_scheme Probing scheme + * @param storage_ref Non-owning ref of slot storage + */ + __host__ __device__ explicit constexpr open_addressing_ref_impl( + value_type empty_slot_sentinel, + key_equal const& predicate, + probing_scheme_type const& probing_scheme, + storage_ref_type storage_ref) noexcept + : empty_slot_sentinel_{empty_slot_sentinel}, + predicate_{ + this->extract_key(empty_slot_sentinel), this->extract_key(empty_slot_sentinel), predicate}, + probing_scheme_{probing_scheme}, + storage_ref_{storage_ref} + { + } + + /** + * @brief Constructs open_addressing_ref_impl. + * + * @param empty_slot_sentinel Sentinel indicating an empty slot + * @param erased_key_sentinel Sentinel indicating an erased key + * @param predicate Key equality binary callable + * @param probing_scheme Probing scheme + * @param storage_ref Non-owning ref of slot storage + */ + __host__ __device__ explicit constexpr open_addressing_ref_impl( + value_type empty_slot_sentinel, + key_type erased_key_sentinel, + key_equal const& predicate, + probing_scheme_type const& probing_scheme, + storage_ref_type storage_ref) noexcept + : empty_slot_sentinel_{empty_slot_sentinel}, + predicate_{this->extract_key(empty_slot_sentinel), erased_key_sentinel, predicate}, + probing_scheme_{probing_scheme}, + storage_ref_{storage_ref} + { + } + + /** + * @brief Gets the sentinel value used to represent an empty key slot. + * + * @return The sentinel value used to represent an empty key slot + */ + [[nodiscard]] __host__ __device__ constexpr key_type empty_key_sentinel() const noexcept + { + return this->predicate_.empty_sentinel_; + } + + /** + * @brief Gets the sentinel value used to represent an empty payload slot. + * + * @return The sentinel value used to represent an empty payload slot + */ + template > + [[nodiscard]] __host__ __device__ constexpr auto empty_value_sentinel() const noexcept + { + return this->extract_payload(this->empty_slot_sentinel()); + } + + /** + * @brief Gets the sentinel value used to represent an erased key slot. + * + * @return The sentinel value used to represent an erased key slot + */ + [[nodiscard]] __host__ __device__ constexpr key_type erased_key_sentinel() const noexcept + { + return this->predicate_.erased_sentinel_; + } + + /** + * @brief Gets the sentinel used to represent an empty slot. + * + * @return The sentinel value used to represent an empty slot + */ + [[nodiscard]] __host__ __device__ constexpr value_type empty_slot_sentinel() const noexcept + { + return empty_slot_sentinel_; + } + + /** + * @brief Returns the function that compares keys for equality. + * + * @return The key equality predicate + */ + [[nodiscard]] __host__ + __device__ constexpr detail::equal_wrapper + predicate() const noexcept + { + return this->predicate_; + } + + /** + * @brief Gets the key comparator. + * + * @return The comparator used to compare keys + */ + [[nodiscard]] __host__ __device__ constexpr key_equal key_eq() const noexcept + { + return this->predicate().equal_; + } + + /** + * @brief Gets the probing scheme. + * + * @return The probing scheme used for the container + */ + [[nodiscard]] __host__ __device__ constexpr probing_scheme_type probing_scheme() const noexcept + { + return probing_scheme_; + } + + /** + * @brief Gets the function(s) used to hash keys + * + * @return The function(s) used to hash keys + */ + [[nodiscard]] __host__ __device__ constexpr hasher hash_function() const noexcept + { + return this->probing_scheme().hash_function(); + } + + /** + * @brief Gets the non-owning storage ref. + * + * @return The non-owning storage ref of the container + */ + [[nodiscard]] __host__ __device__ constexpr storage_ref_type storage_ref() const noexcept + { + return storage_ref_; + } + + /** + * @brief Gets the maximum number of elements the container can hold. + * + * @return The maximum number of elements the container can hold + */ + [[nodiscard]] __host__ __device__ constexpr auto capacity() const noexcept + { + return storage_ref_.capacity(); + } + + /** + * @brief Gets the bucket extent of the current storage. + * + * @return The bucket extent. + */ + [[nodiscard]] __host__ __device__ constexpr extent_type extent() const noexcept + { + return storage_ref_.extent(); + } + + /** + * @brief Returns an iterator to one past the last slot. + * + * @return An iterator to one past the last slot + */ + [[nodiscard]] __host__ __device__ constexpr iterator end() const noexcept + { + return storage_ref_.end(); + } + + /** + * @brief Returns an iterator to one past the last slot. + * + * @return An iterator to one past the last slot + */ + [[nodiscard]] __host__ __device__ constexpr iterator end() noexcept { return storage_ref_.end(); } + + /** + * @brief Makes a copy of the current device reference using non-owned memory. + * + * This function is intended to be used to create shared memory copies of small static data + * structures, although global memory can be used as well. + * + * @tparam CG The type of the cooperative thread group + * + * @param g The cooperative thread group used to copy the data structure + * @param memory_to_use Array large enough to support `capacity` elements. Object does not take + * the ownership of the memory + */ + template + __device__ void make_copy(CG g, value_type* const memory_to_use) const noexcept + { + auto const num_slots = this->capacity(); +#if defined(CUCO_HAS_CUDA_BARRIER) +#pragma nv_diagnostic push +// Disables `barrier` initialization warning. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ cuda::barrier barrier; +#pragma nv_diagnostic pop + if (g.thread_rank() == 0) { init(&barrier, g.size()); } + g.sync(); + + cuda::memcpy_async( + g, memory_to_use, this->storage_ref().data(), sizeof(value_type) * num_slots, barrier); + + barrier.arrive_and_wait(); +#else + value_type const* const slots_ptr = this->storage_ref().data(); + for (size_type i = g.thread_rank(); i < num_slots; i += g.size()) { + memory_to_use[i] = slots_ptr[i]; + } + g.sync(); +#endif + } + + /** + * @brief Initializes the container storage. + * + * @note This function synchronizes the group `tile`. + * + * @tparam CG The type of the cooperative thread group + * + * @param tile The cooperative thread group used to initialize the container + */ + template + __device__ constexpr void initialize(CG tile) noexcept + { + auto tid = tile.thread_rank(); + auto const extent = static_cast(this->extent()); + + auto* const slots_ptr = this->storage_ref().data(); + while (tid < extent) { + slots_ptr[tid] = this->empty_slot_sentinel(); + tid += tile.size(); + } + + tile.sync(); + } + + /** + * @brief Inserts an element. + * + * @tparam Value Input type which is convertible to 'value_type' + * + * @param value The element to insert + * + * @return True if the given element is successfully inserted + */ + template + __device__ bool insert(Value value) noexcept + { + static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); + + auto val = this->heterogeneous_value(value); + auto key = this->extract_key(val); + + auto probing_iter = + probing_scheme_.template make_iterator(key, storage_ref_.extent()); + auto const init_idx = *probing_iter; + size_type probe_step = 0; + + while (true) { + auto const bucket_slots = storage_ref_[*probing_iter]; + + bool retry = false; + + for (auto& slot_content : bucket_slots) { + auto const eq_res = this->predicate_.template operator()( + key, this->extract_key(slot_content)); + + if constexpr (not allows_duplicates) { + // If the key is already in the container, return false + if (eq_res == detail::equal_result::EQUAL) { return false; } + } + // Robin Hood claims only a true empty here; a tombstone carries an age and is handled as a + // resident by the displacement test below. Skipping it must gate the CAS (once claimed it + // is already consumed), so it is folded into this condition. + if (eq_res == detail::equal_result::AVAILABLE and not this->is_erased(slot_content)) { + auto const intra_bucket_index = cuda::std::distance(bucket_slots.begin(), &slot_content); + switch (attempt_insert( + this->get_slot_ptr(*probing_iter, intra_bucket_index), slot_content, val)) { + case insert_result::DUPLICATE: { + if constexpr (allows_duplicates) { + [[fallthrough]]; + } else { + return false; + } + } + case insert_result::CONTINUE: { + // Retry on a lost CAS. Plain probing keeps scanning this (now stale) bucket; Robin + // Hood must re-read it instead, so the in-flight pair is re-evaluated against the new + // occupants -- otherwise it could be placed past a slot it should have displaced, + // breaking the invariant (and therefore lookups). + retry = true; + break; + } + case insert_result::SUCCESS: return true; + } + if (retry) { break; } // leave the scan to re-read the bucket + } + + // Robin Hood swap test. A resident "richer" than the in-flight pair (a smaller probe + // distance than our current probe step) is displaced: we swap our pair into its slot, adopt + // the evicted resident, and re-probe forward. A tombstone is treated as a resident too -- + // its age comes from its payload (`robin_hood_age`) -- but picking one up *consumes* it: we + // take the slot and are done, since there is nothing to carry forward. + if (eq_res == detail::equal_result::UNEQUAL or this->is_erased(slot_content)) { + auto const intra_bucket_index = cuda::std::distance(bucket_slots.begin(), &slot_content); + auto const evicted_age = this->robin_hood_age( + slot_content, static_cast(*probing_iter + intra_bucket_index)); + if (evicted_age < probe_step) { + if (this->attempt_insert(this->get_slot_ptr(*probing_iter, intra_bucket_index), + slot_content, + val) == insert_result::SUCCESS) { + // Consuming a tombstone reuses its freed slot -- nothing to carry, so we are done. + if (this->is_erased(slot_content)) { return true; } + // Adopt the evicted pair and re-probe THIS bucket -- its bucket distance here is + // `evicted_age`, and it may belong in another slot of the same bucket: an empty + // one, or one holding an even-richer resident it can displace in turn. Re-reading + // the bucket (rather than advancing past it) is the within-bucket linear probe, + // i.e. the combined bucket+slot distance that makes displacement correct for + // bucket_size > 1. The `slot_distance` term cancels in every comparison, so it + // never appears here; it shows up only as this slot-by-slot continuation. + // `bit_cast` keeps the adoption valid for heterogeneous insert types + // (layout-compatible by contract; identity in the common case). + val = cuda::std::bit_cast(slot_content); + key = this->extract_key(val); + probe_step = evicted_age; + } + retry = + true; // re-read this bucket: re-probe with the victim, or re-evaluate a lost CAS + break; + } + } + } + + if (retry) { continue; } // re-probe (re-read this bucket, or move on after displacement) + ++probe_step; + + ++probing_iter; + if (*probing_iter == init_idx) { return false; } + } + } + + /** + * @brief Inserts an element. + * + * @tparam Value Input type which is convertible to 'value_type' + * @tparam ParentCG Type of parent Cooperative Group + * + * @param group The Cooperative Group used to perform group insert + * @param value The element to insert + * + * @return True if the given element is successfully inserted + */ + template + __device__ bool insert(cooperative_groups::thread_block_tile group, + Value value) noexcept + { + auto val = this->heterogeneous_value(value); + auto key = this->extract_key(val); + auto probing_iter = + probing_scheme_.template make_iterator(group, key, storage_ref_.extent()); + auto const init_idx = *probing_iter; + size_type probe_step = 0; + + while (true) { + auto const bucket_slots = storage_ref_[*probing_iter]; + + auto const [state, intra_bucket_index] = [&]() { + bucket_probing_results result{detail::equal_result::UNEQUAL, -1}; + cuda::static_for([&] __device__(auto i) { + if (result.state_ == detail::equal_result::UNEQUAL) { + switch (this->predicate_.template operator()( + key, this->extract_key(bucket_slots[i()]))) { + case detail::equal_result::AVAILABLE: { + // Robin Hood: only a true empty is AVAILABLE; a tombstone is a resident handled by + // the displacement scan below, so leave it UNEQUAL here. + bool empty_slot = not this->is_erased(bucket_slots[i()]); + if (empty_slot) { + result = bucket_probing_results{detail::equal_result::AVAILABLE, i()}; + } + break; + } + case detail::equal_result::EQUAL: { + if constexpr (!allows_duplicates) { + result = bucket_probing_results{detail::equal_result::EQUAL, i()}; + } + break; + } + default: break; + } + } + }); + return result; + }(); + + if constexpr (not allows_duplicates) { + // If the key is already in the container, return false + if (group.any(state == detail::equal_result::EQUAL)) { return false; } + } + + auto const group_contains_available = group.ballot(state == detail::equal_result::AVAILABLE); + if (group_contains_available) { + auto const src_lane = __ffs(group_contains_available) - 1; + auto status = insert_result::CONTINUE; + if (group.thread_rank() == src_lane) { + if constexpr (SupportsErase) { + status = attempt_insert(this->get_slot_ptr(*probing_iter, intra_bucket_index), + bucket_slots[intra_bucket_index], + val); + } else { + status = attempt_insert(this->get_slot_ptr(*probing_iter, intra_bucket_index), + this->empty_slot_sentinel(), + val); + } + } + + switch (group.shfl(status, src_lane)) { + case insert_result::SUCCESS: return true; + case insert_result::DUPLICATE: { + if constexpr (allows_duplicates) { + [[fallthrough]]; + } else { + return false; + } + } + default: continue; + } + } else { + // Robin Hood displacement: no match, no empty slot in this bucket. Displace the first + // resident in probe (lane) order that is richer than the in-flight pair, adopt it, and + // re-probe THIS bucket -- the victim may belong in another slot of it. The within-bucket + // linear probe (combined bucket+slot distance) is identical to the scalar path; the + // `slot_distance` term cancels, so the test is again `resident distance < probe_step`. + cuda::std::int32_t displace_idx = -1; + size_type evicted_age = 0; + cuda::static_for([&] __device__(auto i) { + if (displace_idx < 0) { + // `robin_hood_age` so a tombstone uses its payload-stored age: it is displaced (i.e. + // consumed) exactly when richer than the in-flight pair, like any other resident. + auto const age = + this->robin_hood_age(bucket_slots[i()], static_cast(*probing_iter + i())); + if (age < probe_step) { + displace_idx = i(); + evicted_age = age; + } + } + }); + + auto const group_displaceable = group.ballot(displace_idx >= 0); + if (group_displaceable) { + auto const src_lane = __ffs(group_displaceable) - 1; + auto status = insert_result::CONTINUE; + // Only `src_lane` reads `evicted` meaningfully; other lanes just need a valid value to + // feed the broadcast `shfl` below, so seed it with the empty-slot sentinel. + value_type evicted = this->empty_slot_sentinel(); + if (group.thread_rank() == src_lane) { + evicted = bucket_slots[displace_idx]; + status = attempt_insert(this->get_slot_ptr(*probing_iter, displace_idx), evicted, val); + } + if (group.shfl(status, src_lane) == insert_result::SUCCESS) { + // Consuming a tombstone reuses its freed slot -- nothing to carry, so we are done. + if (group.shfl(this->is_erased(evicted), src_lane)) { return true; } + // Broadcast the evicted pair and its probe distance from the winning lane, and adopt + // it on every lane (all lanes need the new in-flight pair for the next scan). + auto const new_key = group.shfl(this->extract_key(evicted), src_lane); + auto const new_age = group.shfl(evicted_age, src_lane); + value_type evicted_slot; + if constexpr (has_payload) { + auto const new_payload = group.shfl(this->extract_payload(evicted), src_lane); + evicted_slot = value_type{new_key, new_payload}; + } else { + evicted_slot = new_key; + } + val = cuda::std::bit_cast(evicted_slot); + key = this->extract_key(val); + probe_step = new_age; + } + continue; // success: re-probe this bucket with the victim; lost CAS: re-read it + } + // No displaceable resident: fall through to the shared advance below. + ++probe_step; + ++probing_iter; + if (*probing_iter == init_idx) { return false; } + } + } + } + + /** + * @brief Inserts the given element into the container. + * + * @note This API returns a pair consisting of an iterator to the inserted element (or to the + * element that prevented the insertion) and a `bool` denoting whether the insertion took place or + * not. + * + * @tparam Value Input type which is convertible to 'value_type' + * + * @param value The element to insert + * + * @return a pair consisting of an iterator to the element and a bool indicating whether the + * insertion is successful or not. + */ + template + __device__ cuda::std::pair insert_and_find(Value value) noexcept + { + static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); +#if __CUDA_ARCH__ < 700 + // Spinning to ensure that the write to the value part took place requires + // independent thread scheduling introduced with the Volta architecture. + static_assert(sizeof(value_type) <= 8, + "insert_and_find is not supported for slot types larger than 8 bytes on " + "pre-Volta GPUs."); +#endif + + auto val = this->heterogeneous_value(value); + auto key = this->extract_key(val); + auto probing_iter = + probing_scheme_.template make_iterator(key, storage_ref_.extent()); + auto const init_idx = *probing_iter; + size_type probe_step = 0; + // Robin Hood may displace the original key before the chain ends; remember the slot it landed + // in so we return an iterator to it (not to a later victim's slot). + value_type* placed_ptr = nullptr; + + while (true) { + auto const bucket_slots = storage_ref_[*probing_iter]; + + bool retry = false; + + for (auto i = 0; i < bucket_size; ++i) { + auto const eq_res = this->predicate_.template operator()( + key, this->extract_key(bucket_slots[i])); + auto* slot_ptr = this->get_slot_ptr(*probing_iter, i); + + // If the key is already in the container, return false + if (eq_res == detail::equal_result::EQUAL) { + this->maybe_wait_for_payload(slot_ptr); + return {iterator{slot_ptr}, false}; + } + // Robin Hood claims only a true empty here; a tombstone is handled as a resident by the + // displacement test below (see `insert`). + if (eq_res == detail::equal_result::AVAILABLE and not this->is_erased(bucket_slots[i])) { + switch (this->attempt_insert_stable(slot_ptr, bucket_slots[i], val)) { + case insert_result::SUCCESS: { + // The in-flight pair is placed in an empty slot, ending any displacement chain. The + // iterator to return is the original key's slot (captured on its first placement). + auto* result_ptr = slot_ptr; + if (placed_ptr != nullptr) { result_ptr = placed_ptr; } + this->maybe_wait_for_payload(result_ptr); + return {iterator{result_ptr}, true}; + } + case insert_result::DUPLICATE: { + this->maybe_wait_for_payload(slot_ptr); + return {iterator{slot_ptr}, false}; + } + case insert_result::CONTINUE: { + retry = true; + break; + } + } + if (retry) { break; } + } + + // Robin Hood swap test (see `insert` for the full rationale). A tombstone is a resident too + // (age from its payload); picking one up consumes it -- the in-flight pair lands there and + // we are done. + if (eq_res == detail::equal_result::UNEQUAL or this->is_erased(bucket_slots[i])) { + auto const evicted_age = + this->robin_hood_age(bucket_slots[i], static_cast(*probing_iter + i)); + if (evicted_age < probe_step) { + if (this->attempt_insert(slot_ptr, bucket_slots[i], val) == insert_result::SUCCESS) { + if (this->is_erased(bucket_slots[i])) { + // Consumed a tombstone: the in-flight pair is placed here; return the original + // key's slot (this one if it was never displaced). + auto* result_ptr = (placed_ptr != nullptr) ? placed_ptr : slot_ptr; + this->maybe_wait_for_payload(result_ptr); + return {iterator{result_ptr}, true}; + } + if (placed_ptr == nullptr) { placed_ptr = slot_ptr; } // original key's slot + val = cuda::std::bit_cast(bucket_slots[i]); + key = this->extract_key(val); + probe_step = evicted_age; + } + retry = true; + break; + } + } + } + + if (retry) { continue; } + ++probe_step; + ++probing_iter; + if (*probing_iter == init_idx) { return {this->end(), false}; } + }; + } + + /** + * @brief Inserts the given element into the container. + * + * @note This API returns a pair consisting of an iterator to the inserted element (or to the + * element that prevented the insertion) and a `bool` denoting whether the insertion took place or + * not. + * + * @tparam Value Input type which is convertible to 'value_type' + * @tparam ParentCG Type of parent Cooperative Group + * + * @param group The Cooperative Group used to perform group insert_and_find + * @param value The element to insert + * + * @return a pair consisting of an iterator to the element and a bool indicating whether the + * insertion is successful or not. + */ + template + __device__ cuda::std::pair insert_and_find( + cooperative_groups::thread_block_tile group, Value value) noexcept + { +#if __CUDA_ARCH__ < 700 + // Spinning to ensure that the write to the value part took place requires + // independent thread scheduling introduced with the Volta architecture. + static_assert(sizeof(value_type) <= 8, + "insert_and_find is not supported for slot types larger than 8 bytes on " + "pre-Volta GPUs."); +#endif + + auto val = this->heterogeneous_value(value); + auto key = this->extract_key(val); + auto probing_iter = + probing_scheme_.template make_iterator(group, key, storage_ref_.extent()); + auto const init_idx = *probing_iter; + size_type probe_step = 0; + // Robin Hood may displace the original key before the chain ends; remember (broadcast) the slot + // it first landed in so we return an iterator to it. 0 means "not yet placed". + intptr_t placed_ptr = 0; + + while (true) { + auto const bucket_slots = storage_ref_[*probing_iter]; + + auto const [state, intra_bucket_index] = [&]() { + bucket_probing_results result{detail::equal_result::UNEQUAL, -1}; + cuda::static_for([&] __device__(auto i) { + if (result.state_ == detail::equal_result::UNEQUAL) { + auto res = this->predicate_.template operator()( + key, this->extract_key(bucket_slots[i()])); + // Robin Hood: a tombstone is a resident handled by the displacement scan below, not + // AVAILABLE, so leave it UNEQUAL here. + if (res == detail::equal_result::AVAILABLE and this->is_erased(bucket_slots[i()])) { + res = detail::equal_result::UNEQUAL; + } + if (res != detail::equal_result::UNEQUAL) { result = bucket_probing_results{res, i()}; } + } + }); + return result; + }(); + + auto* slot_ptr = this->get_slot_ptr(*probing_iter, intra_bucket_index); + + // If the key is already in the container, return false + auto const group_finds_equal = group.ballot(state == detail::equal_result::EQUAL); + if (group_finds_equal) { + auto const src_lane = __ffs(group_finds_equal) - 1; + auto const res = group.shfl(reinterpret_cast(slot_ptr), src_lane); + if (group.thread_rank() == src_lane) { this->maybe_wait_for_payload(slot_ptr); } + group.sync(); + return {iterator{reinterpret_cast(res)}, false}; + } + + auto const group_contains_available = group.ballot(state == detail::equal_result::AVAILABLE); + if (group_contains_available) { + auto const src_lane = __ffs(group_contains_available) - 1; + auto const res = group.shfl(reinterpret_cast(slot_ptr), src_lane); + auto const status = [&, target_idx = intra_bucket_index]() { + if (group.thread_rank() != src_lane) { return insert_result::CONTINUE; } + return this->attempt_insert_stable(slot_ptr, bucket_slots[target_idx], val); + }(); + + switch (group.shfl(status, src_lane)) { + case insert_result::SUCCESS: { + // The in-flight pair is placed in an empty slot, ending any displacement chain. Return + // the original key's slot (the first placement) if it was displaced earlier. + auto result = res; + if (placed_ptr != 0) { result = placed_ptr; } + if (group.thread_rank() == src_lane) { this->maybe_wait_for_payload(slot_ptr); } + group.sync(); + return {iterator{reinterpret_cast(result)}, true}; + } + case insert_result::DUPLICATE: { + if (group.thread_rank() == src_lane) { this->maybe_wait_for_payload(slot_ptr); } + group.sync(); + return {iterator{reinterpret_cast(res)}, false}; + } + default: continue; + } + } else { + // Robin Hood displacement (see CG `insert` for the full rationale). + cuda::std::int32_t displace_idx = -1; + size_type evicted_age = 0; + cuda::static_for([&] __device__(auto i) { + if (displace_idx < 0) { + // `robin_hood_age` so a tombstone uses its payload-stored age: it is displaced (i.e. + // consumed) exactly when richer than the in-flight pair, like any other resident. + auto const age = + this->robin_hood_age(bucket_slots[i()], static_cast(*probing_iter + i())); + if (age < probe_step) { + displace_idx = i(); + evicted_age = age; + } + } + }); + + auto const group_displaceable = group.ballot(displace_idx >= 0); + if (group_displaceable) { + auto const src_lane = __ffs(group_displaceable) - 1; + auto status = insert_result::CONTINUE; + value_type evicted = this->empty_slot_sentinel(); + intptr_t displaced = 0; + if (group.thread_rank() == src_lane) { + auto* dptr = this->get_slot_ptr(*probing_iter, displace_idx); + evicted = bucket_slots[displace_idx]; + status = attempt_insert(dptr, evicted, val); + displaced = reinterpret_cast(dptr); + } + if (group.shfl(status, src_lane) == insert_result::SUCCESS) { + if (placed_ptr == 0) { placed_ptr = group.shfl(displaced, src_lane); } + // Consumed a tombstone: the in-flight pair is placed in its slot; we are done. Return + // the original key's slot (`placed_ptr`, which is this slot if it was never + // displaced). + if (group.shfl(this->is_erased(evicted), src_lane)) { + if (group.thread_rank() == src_lane) { + this->maybe_wait_for_payload(reinterpret_cast(displaced)); + } + group.sync(); + return {iterator{reinterpret_cast(placed_ptr)}, true}; + } + auto const new_key = group.shfl(this->extract_key(evicted), src_lane); + auto const new_age = group.shfl(evicted_age, src_lane); + value_type evicted_slot; + if constexpr (has_payload) { + auto const new_payload = group.shfl(this->extract_payload(evicted), src_lane); + evicted_slot = value_type{new_key, new_payload}; + } else { + evicted_slot = new_key; + } + val = cuda::std::bit_cast(evicted_slot); + key = this->extract_key(val); + probe_step = new_age; + } + continue; + } + ++probe_step; + ++probing_iter; + if (*probing_iter == init_idx) { return {this->end(), false}; } + } + } + } + + /** + * @brief Erases an element. + * + * @tparam ProbeKey Input type which is convertible to 'key_type' + * + * @param key The element to erase + * + * @return True if the given element is successfully erased + */ + template + __device__ bool erase(ProbeKey key) noexcept + { + static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); + + auto probing_iter = + probing_scheme_.template make_iterator(key, storage_ref_.extent()); + auto const init_idx = *probing_iter; + + while (true) { + auto const bucket_slots = storage_ref_[*probing_iter]; + + for (auto& slot_content : bucket_slots) { + auto const eq_res = + this->predicate_.template operator()(key, this->extract_key(slot_content)); + + // Key doesn't exist, return false + if (eq_res == detail::equal_result::EMPTY) { return false; } + // Key exists, return true if successfully deleted + if (eq_res == detail::equal_result::EQUAL) { + auto const intra_bucket_index = cuda::std::distance(bucket_slots.begin(), &slot_content); + // Robin Hood records the erased key's age in the tombstone payload (1a); other schemes + // use the plain erased sentinel. + value_type erased = this->robin_hood_erased_sentinel( + slot_content, static_cast(*probing_iter + intra_bucket_index)); + switch (attempt_insert_stable( + this->get_slot_ptr(*probing_iter, intra_bucket_index), slot_content, erased)) { + case insert_result::SUCCESS: return true; + case insert_result::DUPLICATE: return false; + default: continue; + } + } + } + ++probing_iter; + if (*probing_iter == init_idx) { return false; } + } + } + + /** + * @brief Erases an element. + * + * @tparam ProbeKey Input type which is convertible to 'key_type' + * @tparam ParentCG Type of parent Cooperative Group + * + * @param group The Cooperative Group used to perform group erase + * @param key The element to erase + * + * @return True if the given element is successfully erased + */ + template + __device__ bool erase(cooperative_groups::thread_block_tile group, + ProbeKey key) noexcept + { + auto probing_iter = + probing_scheme_.template make_iterator(group, key, storage_ref_.extent()); + auto const init_idx = *probing_iter; + + while (true) { + auto const bucket_slots = storage_ref_[*probing_iter]; + + auto const [state, intra_bucket_index] = [&]() { + bucket_probing_results result{detail::equal_result::UNEQUAL, -1}; + cuda::static_for([&] __device__(auto i) { + if (result.state_ == detail::equal_result::UNEQUAL) { + auto res = this->predicate_.template operator()( + key, this->extract_key(bucket_slots[i()])); + if (res != detail::equal_result::UNEQUAL) { result = bucket_probing_results{res, i()}; } + } + }); + return result; + }(); + + auto const group_contains_equal = group.ballot(state == detail::equal_result::EQUAL); + if (group_contains_equal) { + auto const src_lane = __ffs(group_contains_equal) - 1; + auto status = insert_result::CONTINUE; + if (group.thread_rank() == src_lane) { + // Robin Hood records the erased key's age in the tombstone payload (1a); other schemes + // use the plain erased sentinel. + value_type erased = this->robin_hood_erased_sentinel( + bucket_slots[intra_bucket_index], + static_cast(*probing_iter + intra_bucket_index)); + status = attempt_insert_stable(this->get_slot_ptr(*probing_iter, intra_bucket_index), + bucket_slots[intra_bucket_index], + erased); + } + + switch (group.shfl(status, src_lane)) { + case insert_result::SUCCESS: return true; + case insert_result::DUPLICATE: return false; + default: continue; + } + } + + // Key doesn't exist, return false + if (group.any(state == detail::equal_result::EMPTY)) { return false; } + + ++probing_iter; + if (*probing_iter == init_idx) { return false; } + } + } + + /** + * @brief Indicates whether the probe key `key` was inserted into the container. + * + * @note If the probe key `key` was inserted into the container, returns true. Otherwise, returns + * false. + * + * @tparam ProbeKey Probe key type + * + * @param key The key to search for + * + * @return A boolean indicating whether the probe key is present + */ + template + [[nodiscard]] __device__ bool contains(ProbeKey key) const noexcept + { + static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); + auto probing_iter = + probing_scheme_.template make_iterator(key, storage_ref_.extent()); + auto const init_idx = *probing_iter; + size_type probe_step = 0; + + while (true) { + // TODO atomic_ref::load if insert operator is present + auto const bucket_slots = storage_ref_[*probing_iter]; + + for (auto i = 0; i < bucket_size; ++i) { + switch (this->predicate_.template operator()( + key, this->extract_key(bucket_slots[i]))) { + case detail::equal_result::UNEQUAL: continue; + case detail::equal_result::EMPTY: return false; + case detail::equal_result::EQUAL: return true; + } + } + // Robin Hood: a resident richer than us proves the key is absent. + if (this->robin_hood_proves_absent(bucket_slots, *probing_iter, probe_step)) { return false; } + ++probe_step; + ++probing_iter; + if (*probing_iter == init_idx) { return false; } + } + } + + /** + * @brief Indicates whether the probe key `key` was inserted into the container. + * + * @note If the probe key `key` was inserted into the container, returns true. Otherwise, returns + * false. + * + * @tparam ProbeKey Probe key type + * @tparam ParentCG Type of parent Cooperative Group + * + * @param group The Cooperative Group used to perform group contains + * @param key The key to search for + * + * @return A boolean indicating whether the probe key is present + */ + template + [[nodiscard]] __device__ bool contains( + cooperative_groups::thread_block_tile group, ProbeKey key) const noexcept + { + auto probing_iter = + probing_scheme_.template make_iterator(group, key, storage_ref_.extent()); + auto const init_idx = *probing_iter; + size_type probe_step = 0; + + while (true) { + auto const bucket_slots = storage_ref_[*probing_iter]; + + auto const state = [&]() { + auto res = detail::equal_result::UNEQUAL; + for (auto i = 0; i < bucket_size; ++i) { + res = this->predicate_.template operator()( + key, this->extract_key(bucket_slots[i])); + if (res != detail::equal_result::UNEQUAL) { return res; } + } + return res; + }(); + + if (group.any(state == detail::equal_result::EQUAL)) { return true; } + if (group.any(state == detail::equal_result::EMPTY)) { return false; } + + // Robin Hood: a resident richer than us (in any lane's bucket) proves the key is absent. + if (group.any(this->robin_hood_proves_absent(bucket_slots, *probing_iter, probe_step))) { + return false; + } + ++probe_step; + ++probing_iter; + if (*probing_iter == init_idx) { return false; } + } + } + + /** + * @brief Finds an element in the container with key equivalent to the probe key. + * + * @note Returns a un-incrementable input iterator to the element whose key is equivalent to + * `key`. If no such element exists, returns `end()`. + * + * @tparam ProbeKey Probe key type + * + * @param key The key to search for + * + * @return An iterator to the position at which the equivalent key is stored + */ + template + [[nodiscard]] __device__ iterator find(ProbeKey key) const noexcept + { + static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); + auto probing_iter = + probing_scheme_.template make_iterator(key, storage_ref_.extent()); + auto const init_idx = *probing_iter; + size_type probe_step = 0; + + while (true) { + // TODO atomic_ref::load if insert operator is present + auto const bucket_slots = storage_ref_[*probing_iter]; + + for (auto i = 0; i < bucket_size; ++i) { + switch (this->predicate_.template operator()( + key, this->extract_key(bucket_slots[i]))) { + case detail::equal_result::EMPTY: { + return this->end(); + } + case detail::equal_result::EQUAL: { + return iterator{this->get_slot_ptr(*probing_iter, i)}; + } + default: continue; + } + } + // Robin Hood: a resident richer than us proves the key is absent. + if (this->robin_hood_proves_absent(bucket_slots, *probing_iter, probe_step)) { + return this->end(); + } + ++probe_step; + ++probing_iter; + if (*probing_iter == init_idx) { return this->end(); } + } + } + + /** + * @brief Finds an element in the container with key equivalent to the probe key. + * + * @note Returns a un-incrementable input iterator to the element whose key is equivalent to + * `key`. If no such element exists, returns `end()`. + * + * @tparam ProbeKey Probe key type + * @tparam ParentCG Type of parent Cooperative Group + * + * @param group The Cooperative Group used to perform this operation + * @param key The key to search for + * + * @return An iterator to the position at which the equivalent key is stored + */ + template + [[nodiscard]] __device__ iterator + find(cooperative_groups::thread_block_tile group, ProbeKey key) const noexcept + { + auto probing_iter = + probing_scheme_.template make_iterator(group, key, storage_ref_.extent()); + auto const init_idx = *probing_iter; + size_type probe_step = 0; + + while (true) { + auto const bucket_slots = storage_ref_[*probing_iter]; + + auto const [state, intra_bucket_index] = [&]() { + bucket_probing_results result{detail::equal_result::UNEQUAL, -1}; + cuda::static_for([&] __device__(auto i) { + if (result.state_ == detail::equal_result::UNEQUAL) { + auto res = this->predicate_.template operator()( + key, this->extract_key(bucket_slots[i()])); + if (res != detail::equal_result::UNEQUAL) { result = bucket_probing_results{res, i()}; } + } + }); + return result; + }(); + + // Find a match for the probe key, thus return an iterator to the entry + auto const group_finds_match = group.ballot(state == detail::equal_result::EQUAL); + if (group_finds_match) { + auto const src_lane = __ffs(group_finds_match) - 1; + auto const res = group.shfl( + reinterpret_cast(this->get_slot_ptr(*probing_iter, intra_bucket_index)), + src_lane); + return iterator{reinterpret_cast(res)}; + } + + // Find an empty slot, meaning that the probe key isn't present in the container + if (group.any(state == detail::equal_result::EMPTY)) { return this->end(); } + + // Robin Hood: a resident richer than us (in any lane's bucket) proves the key is absent. + if (group.any(this->robin_hood_proves_absent(bucket_slots, *probing_iter, probe_step))) { + return this->end(); + } + ++probe_step; + ++probing_iter; + if (*probing_iter == init_idx) { return this->end(); } + } + } + + /** + * @brief Counts the occurrence of a given key contained in the container + * + * @tparam ProbeKey Probe key type + * + * @param key The key to count for + * + * @return Number of occurrences found by the current thread + */ + template + [[nodiscard]] __device__ size_type count(ProbeKey key) const noexcept + { + if constexpr (not allows_duplicates) { + return static_cast(this->contains(key)); + } else { + auto probing_iter = + probing_scheme_.template make_iterator(key, storage_ref_.extent()); + auto const init_idx = *probing_iter; + size_type count = 0; + + while (true) { + auto const bucket_slots = storage_ref_[*probing_iter]; + cuda::std::int32_t equals[bucket_size] = {0}; + bool empty_found = false; + + cuda::static_for([&] __device__(auto i) { + auto const result = predicate_.template operator()( + key, this->extract_key(bucket_slots[i()])); + equals[i()] = (result == detail::equal_result::EQUAL); + if (result == detail::equal_result::EMPTY) { empty_found = true; } + }); + + count += thrust::reduce(thrust::seq, equals, equals + bucket_size); + + if (empty_found) { return count; } + + ++probing_iter; + if (*probing_iter == init_idx) { return count; } + } + } + } + + /** + * @brief Counts the occurrence of a given key contained in the container + * + * @tparam ProbeKey Probe key type + * @tparam ParentCG Type of parent Cooperative Group + * + * @param group The Cooperative Group used to perform group count + * @param key The key to count for + * + * @return Number of occurrences found by the current thread + */ + template + [[nodiscard]] __device__ size_type + count(cooperative_groups::thread_block_tile group, ProbeKey key) const noexcept + { + auto probing_iter = + probing_scheme_.template make_iterator(group, key, storage_ref_.extent()); + auto const init_idx = *probing_iter; + size_type count = 0; + + while (true) { + auto const bucket_slots = storage_ref_[*probing_iter]; + cuda::std::int32_t equals[bucket_size] = {0}; + bool empty_found = false; + + cuda::static_for([&] __device__(auto i) { + auto const result = + predicate_.template operator()(key, this->extract_key(bucket_slots[i()])); + equals[i()] = (result == detail::equal_result::EQUAL); + if (result == detail::equal_result::EMPTY) { empty_found = true; } + }); + + count += thrust::reduce(thrust::seq, equals, equals + bucket_size); + + if (group.any(empty_found)) { return count; } + + ++probing_iter; + if (*probing_iter == init_idx) { return count; } + } + } + + /** + * @brief Retrieves all the slots corresponding to all keys in the range `[input_probe_begin, + * input_probe_end)`. + * + * If key `k = *(first + i)` exists in the container, copies `k` to `output_probe` and associated + * slot contents to `output_match`, respectively. The output order is unspecified. + * + * Behavior is undefined if the size of the output range exceeds the number of retrieved slots. + * Use `count()` to determine the size of the output range. + * + * @tparam BlockSize Size of the thread block this operation is executed in + * @tparam InputProbeIt Device accessible input iterator + * @tparam OutputProbeIt Device accessible input iterator whose `value_type` is + * convertible to the `InputProbeIt`'s `value_type` + * @tparam OutputMatchIt Device accessible input iterator whose `value_type` is + * convertible to the container's `value_type` + * @tparam AtomicCounter Integral atomic counter type that follows the same semantics as + * `cuda::(std::)atomic(_ref)` + * + * @param block Thread block this operation is executed in + * @param input_probe_begin Beginning of the input sequence of keys + * @param input_probe_end End of the input sequence of keys + * @param output_probe Beginning of the sequence of keys corresponding to matching elements in + * `output_match` + * @param output_match Beginning of the sequence of matching elements + * @param atomic_counter Atomic object of integral type that is used to count the + * number of output elements + */ + template + __device__ void retrieve(cooperative_groups::thread_block const& block, + InputProbeIt input_probe_begin, + InputProbeIt input_probe_end, + OutputProbeIt output_probe, + OutputMatchIt output_match, + AtomicCounter& atomic_counter) const + { + auto constexpr is_outer = false; + auto const n = cuco::detail::distance(input_probe_begin, input_probe_end); + auto const always_true_stencil = cuda::constant_iterator(true); + auto const identity_predicate = cuda::std::identity{}; + this->retrieve_impl(block, + input_probe_begin, + n, + always_true_stencil, + identity_predicate, + output_probe, + output_match, + atomic_counter); + } + + /** + * @brief Retrieves all the slots corresponding to all keys in the range `[input_probe_begin, + * input_probe_end)`. + * + * If key `k = *(first + i)` exists in the container, copies `k` to `output_probe` and associated + * slot contents to `output_match`, respectively. The output order is unspecified. + * + * Behavior is undefined if the size of the output range exceeds the number of retrieved slots. + * Use `count()` to determine the size of the output range. + * + * If a key `k` has no matches in the container, then `{key, empty_slot_sentinel}` will be added + * to the output sequence. + * + * @tparam BlockSize Size of the thread block this operation is executed in + * @tparam InputProbeIt Device accessible input iterator + * @tparam OutputProbeIt Device accessible input iterator whose `value_type` is + * convertible to the `InputProbeIt`'s `value_type` + * @tparam OutputMatchIt Device accessible input iterator whose `value_type` is + * convertible to the container's `value_type` + * @tparam AtomicCounter Integral atomic counter type that follows the same semantics as + * `cuda::(std::)atomic(_ref)` + * + * @param block Thread block this operation is executed in + * @param input_probe_begin Beginning of the input sequence of keys + * @param input_probe_end End of the input sequence of keys + * @param output_probe Beginning of the sequence of keys corresponding to matching elements in + * `output_match` + * @param output_match Beginning of the sequence of matching elements + * @param atomic_counter Atomic object of integral type that is used to count the + * number of output elements + */ + template + __device__ void retrieve_outer(cooperative_groups::thread_block const& block, + InputProbeIt input_probe_begin, + InputProbeIt input_probe_end, + OutputProbeIt output_probe, + OutputMatchIt output_match, + AtomicCounter& atomic_counter) const + { + auto constexpr is_outer = true; + auto const n = cuco::detail::distance(input_probe_begin, input_probe_end); + auto const always_true_stencil = cuda::constant_iterator(true); + auto const identity_predicate = cuda::std::identity{}; + this->retrieve_impl(block, + input_probe_begin, + n, + always_true_stencil, + identity_predicate, + output_probe, + output_match, + atomic_counter); + } + + /** + * @brief Retrieves all the slots corresponding to all keys in the range `[input_probe_begin, + * input_probe_end)` if `pred` of the corresponding stencil returns true. + * + * If key `k = *(first + i)` exists in the container and `pred( *(stencil + i) )` returns true, + * copies `k` to `output_probe` and associated slot contents to `output_match`, + * respectively. The output order is unspecified. + * + * Behavior is undefined if the size of the output range exceeds the number of retrieved slots. + * Use `count()` to determine the size of the output range. + * + * @tparam BlockSize Size of the thread block this operation is executed in + * @tparam InputProbeIt Device accessible input iterator + * @tparam StencilIt Device accessible random access iterator whose value_type is + * convertible to Predicate's argument type + * @tparam Predicate Unary predicate callable whose return type must be convertible to `bool` + * and argument type is convertible from `std::iterator_traits::value_type` + * @tparam OutputProbeIt Device accessible input iterator whose `value_type` is + * convertible to the `InputProbeIt`'s `value_type` + * @tparam OutputMatchIt Device accessible input iterator whose `value_type` is + * convertible to the container's `value_type` + * @tparam AtomicCounter Integral atomic counter type that follows the same semantics as + * `cuda::(std::)atomic(_ref)` + * + * @param block Thread block this operation is executed in + * @param input_probe_begin Beginning of the input sequence of keys + * @param input_probe_end End of the input sequence of keys + * @param stencil Beginning of the stencil sequence + * @param pred Predicate to test on every element in the range `[stencil, stencil + n)` + * @param output_probe Beginning of the sequence of keys corresponding to matching elements in + * `output_match` + * @param output_match Beginning of the sequence of matching elements + * @param atomic_counter Atomic object of integral type that is used to count the + * number of output elements + */ + template + __device__ void retrieve_if(cooperative_groups::thread_block const& block, + InputProbeIt input_probe_begin, + InputProbeIt input_probe_end, + StencilIt stencil, + Predicate pred, + OutputProbeIt output_probe, + OutputMatchIt output_match, + AtomicCounter& atomic_counter) const + { + auto constexpr is_outer = false; + auto const n = cuco::detail::distance(input_probe_begin, input_probe_end); + this->retrieve_impl( + block, input_probe_begin, n, stencil, pred, output_probe, output_match, atomic_counter); + } + + /** + * @brief Retrieves all the slots corresponding to all keys in the range `[input_probe_begin, + * input_probe_end)`. + * + * If key `k = *(first + i)` exists in the container, copies `k` to `output_probe` and associated + * slot contents to `output_match`, respectively. The output order is unspecified. + * + * Behavior is undefined if the size of the output range exceeds the number of retrieved slots. + * Use `count()` to determine the size of the output range. + * + * If `IsOuter == true` and a key `k` has no matches in the container, then `{key, + * empty_slot_sentinel}` will be added to the output sequence. + * + * @tparam IsOuter Flag indicating if an inner or outer retrieve operation should be performed + * @tparam BlockSize Size of the thread block this operation is executed in + * @tparam InputProbeIt Device accessible input iterator + * @tparam StencilIt Device accessible random access iterator whose value_type is + * convertible to Predicate's argument type + * @tparam Predicate Unary predicate callable whose return type must be convertible to `bool` + * and argument type is convertible from `std::iterator_traits::value_type` + * @tparam OutputProbeIt Device accessible input iterator whose `value_type` is + * convertible to the `InputProbeIt`'s `value_type` + * @tparam OutputMatchIt Device accessible input iterator whose `value_type` is + * convertible to the container's `value_type` + * @tparam AtomicCounter Integral atomic type that follows the same semantics as + * `cuda::(std::)atomic(_ref)` + * + * @param block Thread block this operation is executed in + * @param input_probe Beginning of the input sequence of keys + * @param n Number of input keys + * @param stencil Beginning of the stencil sequence + * @param pred Predicate to test on every element in the range `[stencil, stencil + n)` + * @param output_probe Beginning of the sequence of keys corresponding to matching elements in + * `output_match` + * @param output_match Beginning of the sequence of matching elements + * @param atomic_counter Atomic object of integral type that is used to count the + * number of output elements + */ + template + __device__ void retrieve_impl(cooperative_groups::thread_block const& block, + InputProbeIt input_probe, + cuco::detail::index_type n, + StencilIt stencil, + Predicate pred, + OutputProbeIt output_probe, + OutputMatchIt output_match, + AtomicCounter& atomic_counter) const + { + namespace cg = cooperative_groups; + + if (n == 0) { return; } + + using probe_type = typename cuda::std::iterator_traits::value_type; + + // tuning parameter + auto constexpr buffer_multiplier = 1; + static_assert(buffer_multiplier > 0); + + auto constexpr probing_tile_size = cg_size; + auto constexpr flushing_tile_size = cuco::detail::warp_size(); + static_assert(flushing_tile_size >= probing_tile_size); + + auto constexpr num_flushing_tiles = BlockSize / flushing_tile_size; + auto constexpr max_matches_per_step = flushing_tile_size * bucket_size; + auto constexpr buffer_size = buffer_multiplier * max_matches_per_step + flushing_tile_size; + + auto const flushing_tile = cg::tiled_partition(block); + auto const probing_tile = cg::tiled_partition(block); + + auto const flushing_tile_id = flushing_tile.meta_group_rank(); + auto const stride = probing_tile.meta_group_size(); + auto idx = probing_tile.meta_group_rank(); + + __shared__ cuco::pair buffers[num_flushing_tiles][buffer_size]; + __shared__ cuda::std::int32_t counters[num_flushing_tiles]; + + if (flushing_tile.thread_rank() == 0) { counters[flushing_tile_id] = 0; } + flushing_tile.sync(); + + auto flush_buffers = [&](auto tile) { + size_type offset = 0; + auto const count = counters[flushing_tile_id]; + auto const rank = tile.thread_rank(); + if (rank == 0) { offset = atomic_counter.fetch_add(count, cuda::memory_order_relaxed); } + offset = tile.shfl(offset, 0); + + // flush_buffers + for (auto i = rank; i < count; i += tile.size()) { + *(output_probe + offset + i) = buffers[flushing_tile_id][i].first; + *(output_match + offset + i) = buffers[flushing_tile_id][i].second; + } + }; + + while (flushing_tile.any(idx < n)) { + bool active_flag = idx < n and pred(*(stencil + idx)); + auto const active_flushing_tile = + cg::binary_partition(flushing_tile, active_flag); + + if (active_flag) { + // perform probing + // make sure the flushing_tile is converged at this point to get a coalesced load + auto const probe_key = *(input_probe + idx); + + auto probing_iter = probing_scheme_.template make_iterator( + probing_tile, probe_key, storage_ref_.extent()); + auto const init_idx = *probing_iter; + + bool running = true; + [[maybe_unused]] bool found_match = false; + + bool equals[bucket_size]; + cuda::std::uint32_t exists[bucket_size]; + + while (active_flushing_tile.any(running)) { + if (running) { + // TODO atomic_ref::load if insert operator is present + auto const bucket_slots = this->storage_ref_[*probing_iter]; + + cuda::static_for([&] __device__(auto i) { + equals[i()] = false; + if (running) { + // inspect slot content + switch (this->predicate_.template operator()( + probe_key, this->extract_key(bucket_slots[i()]))) { + case detail::equal_result::EMPTY: { + running = false; + break; + } + case detail::equal_result::EQUAL: { + if constexpr (!AllowsDuplicates) { running = false; } + equals[i()] = true; + break; + } + default: { + break; + } + } + } + }); + + probing_tile.sync(); + running = probing_tile.all(running); + cuda::static_for( + [&](auto i) { exists[i()] = probing_tile.ballot(equals[i()]); }); + + // Fill the buffer if any matching keys are found + auto const lane_id = probing_tile.thread_rank(); + if (thrust::any_of(thrust::seq, exists, exists + bucket_size, cuda::std::identity{})) { + if constexpr (IsOuter) { found_match = true; } + + cuda::std::int32_t num_matches[bucket_size]; + + cuda::static_for( + [&](auto i) { num_matches[i()] = __popc(exists[i()]); }); + + cuda::std::int32_t output_idx; + if (lane_id == 0) { + auto const total_matches = + thrust::reduce(thrust::seq, num_matches, num_matches + bucket_size); + auto ref = cuda::atomic_ref{ + counters[flushing_tile_id]}; + output_idx = ref.fetch_add(total_matches, cuda::memory_order_relaxed); + } + output_idx = probing_tile.shfl(output_idx, 0); + + cuda::std::int32_t matches_offset = 0; + cuda::static_for([&] __device__(auto i) { + if (equals[i()]) { + auto const lane_offset = + detail::count_least_significant_bits(exists[i()], lane_id); + buffers[flushing_tile_id][output_idx + matches_offset + lane_offset] = { + probe_key, bucket_slots[i()]}; + } + matches_offset += num_matches[i()]; + }); + } + // Special handling for outer cases where no match is found + if constexpr (IsOuter) { + if (!running) { + if (!found_match and lane_id == 0) { + auto ref = cuda::atomic_ref{ + counters[flushing_tile_id]}; + auto const output_idx = ref.fetch_add(1, cuda::memory_order_relaxed); + buffers[flushing_tile_id][output_idx] = {probe_key, this->empty_slot_sentinel()}; + } + } + } + } // if running + + active_flushing_tile.sync(); + // if the buffer has not enough empty slots for the next iteration + if (counters[flushing_tile_id] > (buffer_size - max_matches_per_step)) { + flush_buffers(active_flushing_tile); + active_flushing_tile.sync(); + + // reset buffer counter + if (active_flushing_tile.thread_rank() == 0) { counters[flushing_tile_id] = 0; } + active_flushing_tile.sync(); + } + + // onto the next probing bucket + ++probing_iter; + if (*probing_iter == init_idx) { running = false; } + } // while running + } // if active_flag + + // onto the next key + idx += stride; + } + + flushing_tile.sync(); + // entire flusing_tile has finished; flush remaining elements + if (counters[flushing_tile_id] > 0) { flush_buffers(flushing_tile); } + } + + /** + * @brief For a given key, applies the function object `callback_op` to the copy of all + * corresponding matches found in the container. + * + * @note The return value of `callback_op`, if any, is ignored. + * + * @tparam ProbeKey Probe key type + * @tparam CallbackOp Type of unary callback function object + * + * @param key The key to search for + * @param callback_op Function to apply to every matched slot + */ + template + __device__ void for_each(ProbeKey key, CallbackOp&& callback_op) const noexcept + { + static_assert(cg_size == 1, "Non-CG operation is incompatible with the current probing scheme"); + auto probing_iter = + probing_scheme_.template make_iterator(key, storage_ref_.extent()); + auto const init_idx = *probing_iter; + + while (true) { + // TODO atomic_ref::load if insert operator is present + auto const bucket_slots = this->storage_ref_[*probing_iter]; + + bool should_return = false; + cuda::static_for([&] __device__(auto i) { + if (!should_return) { + switch (this->predicate_.template operator()( + key, this->extract_key(bucket_slots[i()]))) { + case detail::equal_result::EMPTY: { + should_return = true; + break; + } + case detail::equal_result::EQUAL: { + callback_op(bucket_slots[i()]); + break; + } + default: break; + } + } + }); + if (should_return) { return; } + ++probing_iter; + if (*probing_iter == init_idx) { return; } + } + } + + /** + * @brief For a given key, applies the function object `callback_op` to the copy of all + * corresponding matches found in the container. + * + * @note This function uses cooperative group semantics, meaning that any thread may call the + * callback if it finds a matching element. If multiple elements are found within the same group, + * each thread with a match will call the callback with its associated element. + * + * @note The return value of `callback_op`, if any, is ignored. + * + * @note Synchronizing `group` within `callback_op` is undefined behavior. + * + * @tparam ProbeKey Probe key type + * @tparam CallbackOp Type of unary callback function object + * @tparam ParentCG Type of parent Cooperative Group + * + * @param group The Cooperative Group used to perform this operation + * @param key The key to search for + * @param callback_op Function to apply to every matched slot + */ + template + __device__ void for_each(cooperative_groups::thread_block_tile group, + ProbeKey key, + CallbackOp&& callback_op) const noexcept + { + auto probing_iter = + probing_scheme_.template make_iterator(group, key, storage_ref_.extent()); + auto const init_idx = *probing_iter; + bool empty = false; + + while (true) { + // TODO atomic_ref::load if insert operator is present + auto const bucket_slots = this->storage_ref_[*probing_iter]; + + for (cuda::std::int32_t i = 0; i < bucket_size and !empty; ++i) { + switch (this->predicate_.template operator()( + key, this->extract_key(bucket_slots[i]))) { + case detail::equal_result::EMPTY: { + empty = true; + continue; + } + case detail::equal_result::EQUAL: { + callback_op(bucket_slots[i]); + continue; + } + default: { + continue; + } + } + } + if (group.any(empty)) { return; } + + ++probing_iter; + if (*probing_iter == init_idx) { return; } + } + } + + /** + * @brief Applies the function object `callback_op` to the copy of every slot in the container + * with key equivalent to the probe key and can additionally perform work that requires + * synchronizing the Cooperative Group performing this operation. + * + * @note This function uses cooperative group semantics, meaning that any thread may call the + * callback if it finds a matching element. If multiple elements are found within the same group, + * each thread with a match will call the callback with its associated element. + * + * @note Synchronizing `group` within `callback_op` is undefined behavior. + * + * @note The return value of `callback_op`, if any, is ignored. + * + * @note The `sync_op` function can be used to perform work that requires synchronizing threads in + * `group` inbetween probing steps, where the number of probing steps performed between + * synchronization points is capped by `bucket_size * cg_size`. The functor will be called right + * after the current probing bucket has been traversed. + * + * @tparam ProbeKey Probe key type + * @tparam CallbackOp Type of unary callback function object + * @tparam SyncOp Type of function object which accepts the current `group` object + * @tparam ParentCG Type of parent Cooperative Group + * + * @param group The Cooperative Group used to perform this operation + * @param key The key to search for + * @param callback_op Function to apply to every matched slot + * @param sync_op Function that is allowed to synchronize `group` inbetween probing buckets + */ + template + __device__ void for_each(cooperative_groups::thread_block_tile group, + ProbeKey key, + CallbackOp&& callback_op, + SyncOp&& sync_op) const noexcept + { + auto probing_iter = + probing_scheme_.template make_iterator(group, key, storage_ref_.extent()); + auto const init_idx = *probing_iter; + bool empty = false; + + while (true) { + // TODO atomic_ref::load if insert operator is present + auto const bucket_slots = this->storage_ref_[*probing_iter]; + + for (cuda::std::int32_t i = 0; i < bucket_size and !empty; ++i) { + switch (this->predicate_.template operator()( + key, this->extract_key(bucket_slots[i]))) { + case detail::equal_result::EMPTY: { + empty = true; + continue; + } + case detail::equal_result::EQUAL: { + callback_op(bucket_slots[i]); + continue; + } + default: { + continue; + } + } + } + sync_op(group); + if (group.any(empty)) { return; } + + ++probing_iter; + if (*probing_iter == init_idx) { return; } + } + } + + /** + * @brief Gets a pointer to the slot at the given probing index and intra-bucket index. + * + * @param probing_idx The current probing index + * @param intra_bucket_idx The index within the bucket (0 for flat storage) + * @return Pointer to the slot + */ + __device__ value_type* get_slot_ptr(size_type probing_idx, + cuda::std::int32_t intra_bucket_idx) const noexcept + { + return storage_ref_.data() + probing_idx + intra_bucket_idx; + } + + /** + * @brief Determines whether the Robin Hood invariant proves the probe key absent at the current + * probe step. + * + * @note Only meaningful for Robin Hood probing. The key is proven absent when the bucket holds a + * resident that is "richer" than the probe key — i.e. whose own probe distance is smaller than + * the probe key's probe distance at the current step (`probe_step`). Such a resident would have + * been displaced on insertion if the probe key lived here, so the probe key cannot be present. + * + * @note Behavior is only well-defined when every slot in the bucket is occupied (the callers + * reach this check only after ruling out empty and matching slots), since probe distance is + * meaningless for an empty slot. + * + * @tparam BucketSlots Bucket slot array type + * + * @param bucket_slots The slots of the bucket currently being probed + * @param bucket_base The slot index of the first slot in the bucket + * @param probe_step The probe key's own probe distance at the current step + * + * @return True if some resident in the bucket is richer than the probe key + */ + template + [[nodiscard]] __device__ bool robin_hood_proves_absent(BucketSlots const& bucket_slots, + size_type bucket_base, + size_type probe_step) const noexcept + { + bool richer = false; + cuda::static_for([&](auto i) { + auto const resident_age = + this->robin_hood_age(bucket_slots[i()], static_cast(bucket_base + i())); + if (resident_age < probe_step) { richer = true; } + }); + return richer; + } + + /** + * @brief Whether `slot` holds a tombstone (erased marker). + * + * @note Returns false when erase is disabled (the erased and empty sentinels coincide, so no slot + * is a tombstone) -- this keeps the test correct even for empty slots. + * + * @param slot The slot to test + * + * @return True if `slot` is an erased tombstone + */ + [[nodiscard]] __device__ bool is_erased(value_type const& slot) const noexcept + { + return not cuco::detail::bitwise_compare(this->erased_key_sentinel(), + this->empty_key_sentinel()) and + cuco::detail::bitwise_compare(this->extract_key(slot), this->erased_key_sentinel()); + } + + /** + * @brief Robin Hood probe distance ("age") of an occupied slot. + * + * A live key's age is its `probe_distance`. A Robin Hood tombstone keeps the age of the key it + * replaced in its payload (the original key is gone and cannot be rehashed; see `erase`), so it + * is read back here -- a tombstone then participates in every Robin Hood comparison exactly like + * the resident it stood in for. + * + * @param slot The (occupied) slot + * @param slot_index The slot's index + * + * @return The slot's probe distance + */ + [[nodiscard]] __device__ size_type robin_hood_age(value_type const& slot, + size_type slot_index) const noexcept + { + if constexpr (has_payload) { + if (this->is_erased(slot)) { return static_cast(this->extract_payload(slot)); } + } + return robin_hood::probe_distance( + probing_scheme_, this->extract_key(slot), slot_index, storage_ref_.extent()); + } + + /** + * @brief The Robin Hood tombstone for erasing the live key currently in `slot` at `slot_index`. + * + * The erased key's age is stashed in the payload (1a) so the tombstone keeps its place in the + * Robin Hood ordering (the original key is gone and cannot be rehashed). Other probing schemes + * use the plain `erased_slot_sentinel()` and never call this. + * + * @param slot The slot's current (live) contents + * @param slot_index The slot's index + * + * @return The value to CAS into the slot to erase it + */ + [[nodiscard]] __device__ value_type + robin_hood_erased_sentinel(value_type const& slot, size_type slot_index) const noexcept + { + static_assert(has_payload, + "Robin Hood erase requires a mapped payload to store the tombstone age"); + auto const age = robin_hood::probe_distance( + probing_scheme_, this->extract_key(slot), slot_index, storage_ref_.extent()); + return cuco::pair{this->erased_key_sentinel(), + static_castempty_value_sentinel())>(age)}; + } + + /** + * @brief Extracts the key from a given value type. + * + * @tparam Value Input type which is convertible to 'value_type' + * + * @param value The input value + * + * @return The key + */ + template + [[nodiscard]] __host__ __device__ constexpr auto extract_key(Value value) const noexcept + { + if constexpr (has_payload) { + return thrust::raw_reference_cast(value).first; + } else { + return thrust::raw_reference_cast(value); + } + } + + /** + * @brief Extracts the payload from a given value type. + * + * @note This function is only available if `this->has_payload == true` + * + * @tparam Value Input type which is convertible to 'value_type' + * + * @param value The input value + * + * @return The payload + */ + template > + [[nodiscard]] __host__ __device__ constexpr auto extract_payload(Value value) const noexcept + { + return thrust::raw_reference_cast(value).second; + } + + /** + * @brief Converts the given type to the container's native `value_type`. + * + * @tparam T Input type which is convertible to 'value_type' + * + * @param value The input value + * + * @return The converted object + */ + template + [[nodiscard]] __device__ constexpr value_type native_value(T value) const noexcept + { + if constexpr (has_payload) { + return {static_cast(this->extract_key(value)), this->extract_payload(value)}; + } else { + return static_cast(value); + } + } + + /** + * @brief Converts the given type to the container's native `value_type` while maintaining the + * heterogeneous key type. + * + * @tparam T Input type which is convertible to 'value_type' + * + * @param value The input value + * + * @return The converted object + */ + template + [[nodiscard]] __device__ constexpr auto heterogeneous_value(T value) const noexcept + { + if constexpr (has_payload and not cuda::std::is_same_v) { + using mapped_type = decltype(this->empty_value_sentinel()); + if constexpr (cuco::detail::is_cuda_std_pair_like::value) { + return cuco::pair{cuda::std::get<0>(value), + static_cast(cuda::std::get<1>(value))}; + } else { + // hail mary (convert using .first/.second members) + return cuco::pair{thrust::raw_reference_cast(value.first), + static_cast(value.second)}; + } + } else { + return thrust::raw_reference_cast(value); + } + } + + /** + * @brief Gets the sentinel used to represent an erased slot. + * + * @return The sentinel value used to represent an erased slot + */ + [[nodiscard]] __device__ constexpr value_type erased_slot_sentinel() const noexcept + { + if constexpr (has_payload) { + return cuco::pair{this->erased_key_sentinel(), this->empty_value_sentinel()}; + } else { + return this->erased_key_sentinel(); + } + } + + /** + * @brief Inserts the specified element with one single CAS operation. + * + * @tparam Value Input type which is convertible to 'value_type' + * + * @param address Pointer to the slot in memory + * @param expected Element to compare against + * @param desired Element to insert + * + * @return Result of this operation, i.e., success/continue/duplicate + */ + template + [[nodiscard]] __device__ constexpr insert_result packed_cas(value_type* address, + value_type expected, + Value desired) noexcept + { + using packed_type = cuco::detail::packed_t; + + auto* slot_ptr = reinterpret_cast(address); + auto* expected_ptr = reinterpret_cast(&expected); + auto* desired_ptr = reinterpret_cast(&desired); + + auto slot_ref = cuda::atomic_ref{*slot_ptr}; + + auto const success = + slot_ref.compare_exchange_strong(*expected_ptr, *desired_ptr, cuda::memory_order_relaxed); + + if (success) { + return insert_result::SUCCESS; + } else { + return this->predicate_.equal_to(this->extract_key(desired), this->extract_key(expected)) == + detail::equal_result::EQUAL + ? insert_result::DUPLICATE + : insert_result::CONTINUE; + } + } + + /** + * @brief Inserts the specified element with two back-to-back CAS operations. + * + * @note This CAS can be used exclusively for `cuco::op::insert` operations. + * + * @tparam Value Input type which is convertible to 'value_type' + * + * @param address Pointer to the slot in memory + * @param expected Element to compare against + * @param desired Element to insert + * + * @return Result of this operation, i.e., success/continue/duplicate + */ + template + [[nodiscard]] __device__ constexpr insert_result back_to_back_cas(value_type* address, + value_type expected, + Value desired) noexcept + { + using mapped_type = cuda::std::decay_tempty_value_sentinel())>; + + auto expected_key = expected.first; + auto expected_payload = this->empty_value_sentinel(); + + cuda::atomic_ref key_ref(address->first); + cuda::atomic_ref payload_ref(address->second); + + auto const key_cas_success = key_ref.compare_exchange_strong( + expected_key, static_cast(desired.first), cuda::memory_order_relaxed); + auto payload_cas_success = payload_ref.compare_exchange_strong( + expected_payload, desired.second, cuda::memory_order_relaxed); + + // if key success + if (key_cas_success) { + while (not payload_cas_success) { + payload_cas_success = + payload_ref.compare_exchange_strong(expected_payload = this->empty_value_sentinel(), + desired.second, + cuda::memory_order_relaxed); + } + return insert_result::SUCCESS; + } else if (payload_cas_success) { + // This is insert-specific, cannot for `erase` operations + payload_ref.store(this->empty_value_sentinel(), cuda::memory_order_relaxed); + } + + // Our key was already present in the slot, so our key is a duplicate + // Shouldn't use `predicate` operator directly since it includes a redundant bitwise compare + if (this->predicate_.equal_to(desired.first, expected_key) == detail::equal_result::EQUAL) { + return insert_result::DUPLICATE; + } + + return insert_result::CONTINUE; + } + + /** + * @brief Inserts the specified element with CAS-dependent write operations. + * + * @tparam Value Input type which is convertible to 'value_type' + * + * @param address Pointer to the slot in memory + * @param expected Element to compare against + * @param desired Element to insert + * + * @return Result of this operation, i.e., success/continue/duplicate + */ + template + [[nodiscard]] __device__ constexpr insert_result cas_dependent_write(value_type* address, + value_type expected, + Value desired) noexcept + { + using mapped_type = cuda::std::decay_tempty_value_sentinel())>; + + cuda::atomic_ref key_ref(address->first); + auto expected_key = expected.first; + auto const success = key_ref.compare_exchange_strong( + expected_key, static_cast(desired.first), cuda::memory_order_relaxed); + + // if key success + if (success) { + cuda::atomic_ref payload_ref(address->second); + payload_ref.store(desired.second, cuda::memory_order_relaxed); + return insert_result::SUCCESS; + } + + // Our key was already present in the slot, so our key is a duplicate + // Shouldn't use `predicate` operator directly since it includes a redundant bitwise compare + if (this->predicate_.equal_to(desired.first, expected_key) == detail::equal_result::EQUAL) { + return insert_result::DUPLICATE; + } + + return insert_result::CONTINUE; + } + + /** + * @brief Attempts to insert an element into a slot. + * + * @note Dispatches the correct implementation depending on the container + * type and presence of other operator mixins. + * + * @tparam Value Input type which is convertible to 'value_type' + * + * @param address Pointer to the slot in memory + * @param expected Element to compare against + * @param desired Element to insert + * + * @return Result of this operation, i.e., success/continue/duplicate + */ + template + [[nodiscard]] __device__ insert_result attempt_insert(value_type* address, + value_type expected, + Value desired) noexcept + { + if constexpr (sizeof(value_type) <= 8) { + return packed_cas(address, expected, desired); + } +#if (__CUDA_ARCH__ >= 900) + else if constexpr (cuco::detail::is_packable()) { + return packed_cas(address, expected, desired); + } +#endif + else if constexpr (has_payload) { +#if (__CUDA_ARCH__ < 700) + return cas_dependent_write(address, expected, desired); +#else + return back_to_back_cas(address, expected, desired); +#endif + } else { + static_assert(cuco::dependent_false, + "No valid atomic CAS path: 16-byte key in a key-only container must be " + "packable (have unique object representations) and target sm_90+."); + } + } + + /** + * @brief Attempts to insert an element into a slot. + * + * @note Dispatches the correct implementation depending on the container + * type and presence of other operator mixins. + * + * @note `stable` indicates that the payload will only be updated once from the sentinel value to + * the desired value, meaning there can be no ABA situations. + * + * @tparam Value Input type which is convertible to 'value_type' + * + * @param address Pointer to the slot in memory + * @param expected Element to compare against + * @param desired Element to insert + * + * @return Result of this operation, i.e., success/continue/duplicate + */ + template + [[nodiscard]] __device__ insert_result attempt_insert_stable(value_type* address, + value_type expected, + Value desired) noexcept + { + if constexpr (sizeof(value_type) <= 8) { + return packed_cas(address, expected, desired); + } +#if (__CUDA_ARCH__ >= 900) + else if constexpr (cuco::detail::is_packable()) { + return packed_cas(address, expected, desired); + } +#endif + else if constexpr (has_payload) { + return cas_dependent_write(address, expected, desired); + } else { + static_assert(cuco::dependent_false, + "No valid atomic CAS path: 16-byte key in a key-only container must be " + "packable (have unique object representations) and target sm_90+."); + } + } + + /** + * @brief Waits until the slot payload has been updated + * + * @note The function will return once the slot payload is no longer equal to the sentinel + * value. + * + * @tparam T Map slot type + * + * @param slot The target slot to check payload with + * @param sentinel The slot sentinel value + */ + template + __device__ void wait_for_payload(T& slot, T sentinel) const noexcept + { + auto ref = cuda::atomic_ref{slot}; + T current; + // TODO exponential backoff strategy + do { + current = ref.load(cuda::std::memory_order_relaxed); + } while (cuco::detail::bitwise_compare(current, sentinel)); + } + + /** + * @brief Conditionally spin-waits for the payload of a non-atomically inserted slot to become + * visible. + * + * For containers where the key and value are inserted by separate instructions + * (`cas_dependent_write` / `back_to_back_cas`), an observer thread may see the key before the + * payload. This helper spins until the payload is visible. For atomic single-CAS paths (slot + * size <= 8 bytes, or a packable slot on sm_90+ via `atom.cas.b128`), the payload is already + * visible and this is a no-op. + * + * @tparam SlotPtr Pointer-like type to a slot holding a `.second` payload member + * + * @param slot_ptr Pointer to the slot whose payload may need waiting on + */ + template + __device__ void maybe_wait_for_payload(SlotPtr slot_ptr) noexcept + { + if constexpr (has_payload and sizeof(value_type) > 8) { +#if (__CUDA_ARCH__ >= 900) + if constexpr (not cuco::detail::is_packable()) { + this->wait_for_payload(slot_ptr->second, this->empty_value_sentinel()); + } +#else + this->wait_for_payload(slot_ptr->second, this->empty_value_sentinel()); +#endif + } + } + + // TODO: Clean up the sentinel handling since it's duplicated in ref and equal wrapper + value_type empty_slot_sentinel_; ///< Sentinel value indicating an empty slot + detail::equal_wrapper + predicate_; ///< Key equality binary callable + probing_scheme_type probing_scheme_; ///< Probing scheme + storage_ref_type storage_ref_; ///< Slot storage ref +}; + +} // namespace robin_hood +} // namespace detail +} // namespace cuco diff --git a/include/cuco/static_map_ref.cuh b/include/cuco/static_map_ref.cuh index 660a7fcbd..1c9691d2b 100644 --- a/include/cuco/static_map_ref.cuh +++ b/include/cuco/static_map_ref.cuh @@ -17,6 +17,7 @@ #pragma once #include +#include #include #include #include @@ -74,7 +75,12 @@ class static_map_ref static constexpr auto allows_duplicates = false; /// Implementation type - using impl_type = detail:: + // + // HARD-WIRE (experimental Robin Hood PR): static_map is routed through the Robin Hood engine + // instead of the generic open-addressing engine. This single line is what makes static_map use + // Robin Hood probing; downstream front-end work replaces it with a proper backend-selection + // mechanism. See robin_hood_refactor_plan.md. + using impl_type = detail::robin_hood:: open_addressing_ref_impl; public: