-
Notifications
You must be signed in to change notification settings - Fork 94
Expand file tree
/
Copy pathexample.cpp
More file actions
70 lines (55 loc) · 2.01 KB
/
example.cpp
File metadata and controls
70 lines (55 loc) · 2.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/numpy.h>
#include <vector>
// -------------
// pure C++ code
// -------------
std::vector<double> length(const std::vector<double>& pos)
{
size_t N = pos.size() / 2;
std::vector<double> output(N*3);
for ( size_t i = 0 ; i < N ; ++i ) {
output[i*3+0] = pos[i*2+0];
output[i*3+1] = pos[i*2+1];
output[i*3+2] = std::pow(pos[i*2+0]*pos[i*2+1],.5);
}
return output;
}
// ----------------
// Python interface
// ----------------
namespace py = pybind11;
// wrap C++ function with NumPy array IO
py::array py_length(py::array_t<double, py::array::c_style | py::array::forcecast> array)
{
// check input dimensions
if ( array.ndim() != 2 )
throw std::runtime_error("Input should be 2-D NumPy array");
if ( array.shape()[1] != 2 )
throw std::runtime_error("Input should have size [N,2]");
// allocate std::vector (to pass to the C++ function)
std::vector<double> pos(array.size());
// copy py::array -> std::vector
std::memcpy(pos.data(),array.data(),array.size()*sizeof(double));
// call pure C++ function
std::vector<double> result = length(pos);
ssize_t ndim = 2;
std::vector<ssize_t> shape = { array.shape()[0] , 3 };
std::vector<ssize_t> strides = { sizeof(double)*3 , sizeof(double) };
// return 2-D NumPy array
return py::array(py::buffer_info(
result.data(), /* data as contiguous array */
sizeof(double), /* size of one scalar */
py::format_descriptor<double>::format(), /* data type */
ndim, /* number of dimensions */
shape, /* shape of the matrix */
strides /* strides for each axis */
));
}
// wrap as Python module
PYBIND11_MODULE(example,m)
{
m.doc() = "pybind11 example plugin";
m.def("length", &py_length, "Calculate the length of an array of vectors");
}