Skip to content

Commit 3cfcc0c

Browse files
committed
Add managed_memory_pointer that is compatible with STL.
The existing `cuda::pointer` uses a fancy reference that overloads `operator&`, and some STL implementations misbehave when that operator does not return the actual memory address of the object. Since `universal_memory_resource` allocates memory that is accessible from both host and device, we need to be able to use these types with STL facilities such as `std::vector`, `std::unique_ptr`, etc. This patch adds a `managed_memory_pointer` implementation that behaves like `cuda::pointer`, but dereferences to a regular C++ reference, allowing the thrust universal allocator to work with STL containers.
1 parent d43f285 commit 3cfcc0c

File tree

8 files changed

+385
-12
lines changed

8 files changed

+385
-12
lines changed
+141
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
#include <thrust/detail/config.h>
2+
3+
#if THRUST_CPP_DIALECT >= 2011
4+
5+
# include <unittest/unittest.h>
6+
7+
# include <thrust/allocate_unique.h>
8+
# include <thrust/memory/detail/device_system_resource.h>
9+
# include <thrust/mr/allocator.h>
10+
# include <thrust/type_traits/is_contiguous_iterator.h>
11+
12+
# include <numeric>
13+
# include <vector>
14+
15+
namespace
16+
{
17+
18+
// Shorthand used throughout these tests: a stateless allocator for T that is
// backed by thrust::universal_memory_resource.
template <typename T>
19+
using allocator =
20+
thrust::mr::stateless_resource_allocator<T, thrust::universal_memory_resource>;
21+
22+
// The managed_memory_pointer class should be identified as a
23+
// contiguous_iterator
24+
THRUST_STATIC_ASSERT(
25+
thrust::is_contiguous_iterator<allocator<int>::pointer>::value);
26+
27+
// Minimal class type holding a single T. The tests in this file use it to
// exercise member access (`->` and `(*p).`) through the pointer type under
// test, which a plain arithmetic T cannot do.
template <typename T>
28+
struct some_object {
29+
// Stores a copy of `data`.
some_object(T data)
30+
: m_data(data)
31+
{}
32+
33+
// Overwrites the stored value.
void setter(T data) { m_data = data; }
34+
// Returns a copy of the stored value.
T getter() const { return m_data; }
35+
36+
private:
37+
T m_data;
38+
};
39+
40+
} // namespace
41+
42+
// Checks that single objects created with thrust::allocate_unique and the
// universal allocator can be read back through both the owning handle and the
// pointer returned by get().
template <typename T>
43+
void TestAllocateUnique()
44+
{
45+
// Simple test to ensure that pointers created with universal_memory_resource
46+
// can be dereferenced and used with STL code. This is necessary as some
47+
// STL implementations break when using fancy references that overload
48+
// `operator&`, so universal_memory_resource uses a special pointer type that
49+
// returns regular C++ references that can be safely used host-side.
50+
51+
// These operations fail to compile with fancy references:
52+
auto pRaw = thrust::allocate_unique<T>(allocator<T>{}, 42);
53+
auto pObj =
54+
thrust::allocate_unique<some_object<T> >(allocator<some_object<T> >{}, 42);
55+
56+
// Compile-time check: get() must yield managed_memory_pointer for both the
// plain T and the class-type allocation.
static_assert(
57+
std::is_same<decltype(pRaw.get()),
58+
thrust::system::cuda::detail::managed_memory_pointer<T> >::value,
59+
"Unexpected pointer returned from unique_ptr::get.");
60+
static_assert(
61+
std::is_same<decltype(pObj.get()),
62+
thrust::system::cuda::detail::managed_memory_pointer<
63+
some_object<T> > >::value,
64+
"Unexpected pointer returned from unique_ptr::get.");
65+
66+
// Read the value through every dereference path: operator* and operator->,
// applied to the owning handle and to the pointer obtained from get().
ASSERT_EQUAL(*pRaw, T(42));
67+
ASSERT_EQUAL(*pRaw.get(), T(42));
68+
ASSERT_EQUAL(pObj->getter(), T(42));
69+
ASSERT_EQUAL((*pObj).getter(), T(42));
70+
ASSERT_EQUAL(pObj.get()->getter(), T(42));
71+
ASSERT_EQUAL((*pObj.get()).getter(), T(42));
72+
}
73+
DECLARE_GENERIC_UNITTEST(TestAllocateUnique);
74+
75+
// Checks that the pointer returned for an array allocation of plain T values
// can be used as an iterator: pointer arithmetic, `<` comparison, pre-increment
// and dereference.
template <typename T>
76+
void TestIterationRaw()
77+
{
78+
// Request 6 elements, passing 42 as the construction argument; the loop
// below asserts that every element reads back as T(42).
auto array = thrust::allocate_unique_n<T>(allocator<T>{}, 6, 42);
79+
80+
// Compile-time check that get() yields managed_memory_pointer<T>.
static_assert(
81+
std::is_same<decltype(array.get()),
82+
thrust::system::cuda::detail::managed_memory_pointer<T> >::value,
83+
"Unexpected pointer returned from unique_ptr::get.");
84+
85+
// Walk the 6 elements using the managed pointer itself as the iterator.
for (auto iter = array.get(), end = array.get() + 6; iter < end; ++iter)
86+
{
87+
// Dereference the managed pointer directly, then via iter.get().
ASSERT_EQUAL(*iter, T(42));
88+
ASSERT_EQUAL(*iter.get(), T(42));
89+
}
90+
}
91+
DECLARE_GENERIC_UNITTEST(TestIterationRaw);
92+
93+
// Same iteration check as TestIterationRaw, but over some_object<T> elements
// so that member access through the pointer (`->` and `(*p).`) is exercised.
template <typename T>
94+
void TestIterationObj()
95+
{
96+
// Request 6 some_object<T> elements, passing 42 as the construction
// argument; the loop below asserts that each getter() returns T(42).
auto array =
97+
thrust::allocate_unique_n<some_object<T> >(allocator<some_object<T> >{},
98+
6,
99+
42);
100+
101+
// Compile-time check that get() yields
// managed_memory_pointer<some_object<T>>.
static_assert(
102+
std::is_same<decltype(array.get()),
103+
thrust::system::cuda::detail::managed_memory_pointer<
104+
some_object<T> > >::value,
105+
"Unexpected pointer returned from unique_ptr::get.");
106+
107+
// Walk the 6 elements using the managed pointer itself as the iterator.
for (auto iter = array.get(), end = array.get() + 6; iter < end; ++iter)
108+
{
109+
// Member access through the managed pointer, then through iter.get().
ASSERT_EQUAL(iter->getter(), T(42));
110+
ASSERT_EQUAL((*iter).getter(), T(42));
111+
ASSERT_EQUAL(iter.get()->getter(), T(42));
112+
ASSERT_EQUAL((*iter.get()).getter(), T(42));
113+
}
114+
}
115+
DECLARE_GENERIC_UNITTEST(TestIterationObj);
116+
117+
// Checks that std::vector can be instantiated with the universal allocator,
// that its `pointer` member type is managed_memory_pointer<T>, and that the
// vector can be resized, filled by a standard algorithm, and indexed.
template <typename T>
118+
void TestStdVector()
119+
{
120+
// Verify that a std::vector using the universal allocator will work with
121+
// STL algorithms.
122+
std::vector<T, allocator<T> > v0;
123+
124+
// Compile-time check on the vector's pointer member type. The message names
// std::vector::pointer (not unique_ptr::get) so that a failure points at the
// property actually being checked here.
static_assert(
125+
std::is_same<typename std::decay<decltype(v0)>::type::pointer,
126+
thrust::system::cuda::detail::managed_memory_pointer<
127+
T > >::value,
128+
"Unexpected pointer type for std::vector using the universal allocator.");
129+
130+
// Fill with 0..5 through the vector's own iterators, then verify each
// element through operator[].
v0.resize(6);
131+
std::iota(v0.begin(), v0.end(), 0);
132+
ASSERT_EQUAL(v0[0], T(0));
133+
ASSERT_EQUAL(v0[1], T(1));
134+
ASSERT_EQUAL(v0[2], T(2));
135+
ASSERT_EQUAL(v0[3], T(3));
136+
ASSERT_EQUAL(v0[4], T(4));
137+
ASSERT_EQUAL(v0[5], T(5));
138+
}
139+
DECLARE_GENERIC_UNITTEST(TestStdVector);
140+
141+
#endif // C++11
+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
CUDACC_FLAGS += -rdc=true

testing/vector.cu

+9-6
Original file line numberDiff line numberDiff line change
@@ -52,24 +52,27 @@ DECLARE_VECTOR_UNITTEST(TestVectorFrontBack);
5252
template <class Vector>
5353
void TestVectorData(void)
5454
{
55+
typedef typename Vector::pointer PointerT;
56+
typedef typename Vector::const_pointer PointerConstT;
57+
5558
Vector v(3);
5659
v[0] = 0; v[1] = 1; v[2] = 2;
5760

5861
ASSERT_EQUAL(0, *v.data());
5962
ASSERT_EQUAL(1, *(v.data() + 1));
6063
ASSERT_EQUAL(2, *(v.data() + 2));
61-
ASSERT_EQUAL(&v.front(), v.data());
62-
ASSERT_EQUAL(&*v.begin(), v.data());
63-
ASSERT_EQUAL(&v[0], v.data());
64+
ASSERT_EQUAL(PointerT(&v.front()), v.data());
65+
ASSERT_EQUAL(PointerT(&*v.begin()), v.data());
66+
ASSERT_EQUAL(PointerT(&v[0]), v.data());
6467

6568
const Vector &c_v = v;
6669

6770
ASSERT_EQUAL(0, *c_v.data());
6871
ASSERT_EQUAL(1, *(c_v.data() + 1));
6972
ASSERT_EQUAL(2, *(c_v.data() + 2));
70-
ASSERT_EQUAL(&c_v.front(), c_v.data());
71-
ASSERT_EQUAL(&*c_v.begin(), c_v.data());
72-
ASSERT_EQUAL(&c_v[0], c_v.data());
73+
ASSERT_EQUAL(PointerConstT(&c_v.front()), c_v.data());
74+
ASSERT_EQUAL(PointerConstT(&*c_v.begin()), c_v.data());
75+
ASSERT_EQUAL(PointerConstT(&c_v[0]), c_v.data());
7376
}
7477
DECLARE_VECTOR_UNITTEST(TestVectorData);
7578

thrust/detail/pointer.inl

+33-3
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include <thrust/detail/config.h>
1818
#include <thrust/detail/pointer.h>
19+
#include <thrust/detail/type_traits.h>
1920

2021

2122
namespace thrust
@@ -109,14 +110,43 @@ template<typename Element, typename Tag, typename Reference, typename Derived>
109110
return static_cast<derived_type&>(*this);
110111
} // end pointer::operator=
111112

113+
namespace detail
114+
{
115+
116+
// Implementation for dereference() when Reference is Element&,
117+
// e.g. cuda's managed_memory_pointer
// Chosen by the true_type tag argument. Dereferences the value returned by
// ptr.get() and returns the result as Reference; no proxy object is built.
118+
template <typename Reference, typename Derived>
119+
__host__ __device__
120+
Reference pointer_dereference_impl(const Derived& ptr,
121+
thrust::detail::true_type /* is_cpp_ref */)
122+
{
123+
return *ptr.get();
124+
}
125+
126+
// Implementation for pointers with proxy references:
// Chosen by the false_type tag argument. Builds the Reference object from the
// pointer itself instead of dereferencing ptr.get().
127+
template <typename Reference, typename Derived>
128+
__host__ __device__
129+
Reference pointer_dereference_impl(const Derived& ptr,
130+
thrust::detail::false_type /* is_cpp_ref */)
131+
{
132+
return Reference(ptr);
133+
}
134+
135+
} // namespace detail
112136

113137
template<typename Element, typename Tag, typename Reference, typename Derived>
114138
__host__ __device__
115139
typename pointer<Element,Tag,Reference,Derived>::super_t::reference
116-
pointer<Element,Tag,Reference,Derived>
117-
::dereference() const
140+
pointer<Element,Tag,Reference,Derived>
141+
::dereference() const
118142
{
119-
return typename super_t::reference(static_cast<const derived_type&>(*this));
143+
// Need to handle cpp refs and fancy refs differently:
144+
typedef typename super_t::reference RefT;
145+
typedef typename thrust::detail::is_reference<RefT>::type IsCppRef;
146+
147+
const derived_type& derivedPtr = static_cast<const derived_type&>(*this);
148+
149+
return detail::pointer_dereference_impl<RefT>(derivedPtr, IsCppRef());
120150
} // end pointer::dereference
121151

122152

thrust/detail/vector_base.inl

+2-2
Original file line numberDiff line numberDiff line change
@@ -540,15 +540,15 @@ template<typename T, typename Alloc>
540540
vector_base<T,Alloc>
541541
::data(void)
542542
{
543-
return &front();
543+
return pointer(&front());
544544
} // end vector_base::data()
545545

546546
template<typename T, typename Alloc>
547547
typename vector_base<T,Alloc>::const_pointer
548548
vector_base<T,Alloc>
549549
::data(void) const
550550
{
551-
return &front();
551+
return const_pointer(&front());
552552
} // end vector_base::data()
553553

554554
template<typename T, typename Alloc>

thrust/mr/allocator.h

+2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
#include <limits>
2424

25+
#include <thrust/detail/config/exec_check_disable.h>
2526
#include <thrust/detail/type_traits/pointer_traits.h>
2627

2728
#include <thrust/mr/detail/config.h>
@@ -93,6 +94,7 @@ class allocator : private validator<MR>
9394
*
9495
* \returns the maximum value of \p std::size_t, divided by the size of \p T.
9596
*/
97+
__thrust_exec_check_disable__
9698
__host__ __device__
9799
size_type max_size() const
98100
{

0 commit comments

Comments
 (0)