diff --git a/src/python/export_tensorwrapper.hpp b/src/python/export_tensorwrapper.hpp index 3d1de3d7..7417941b 100644 --- a/src/python/export_tensorwrapper.hpp +++ b/src/python/export_tensorwrapper.hpp @@ -15,6 +15,7 @@ */ #pragma once +#include #include namespace tensorwrapper { @@ -23,14 +24,16 @@ namespace tensorwrapper { // -- Type factorization used throughout the Python component // ----------------------------------------------------------------------------- +namespace py = pybind11; + /// Type of a C++ handle to a Python module -using py_module_type = pybind11::module_; +using py_module_type = py::module_; /// Type of a reference to an object of type py_module_type using py_module_reference = py_module_type&; /// Type of Python object binding for a C++ class of type @p T template -using py_class_type = pybind11::class_; +using py_class_type = py::class_; } // namespace tensorwrapper diff --git a/src/python/module.cpp b/src/python/module.cpp index 98c028c9..635aa789 100644 --- a/src/python/module.cpp +++ b/src/python/module.cpp @@ -15,7 +15,6 @@ */ #include "tensor/export_tensor.hpp" -#include namespace tensorwrapper { diff --git a/src/python/tensor/export_tensor.cpp b/src/python/tensor/export_tensor.cpp index 8e2d02da..0bb14ed0 100644 --- a/src/python/tensor/export_tensor.cpp +++ b/src/python/tensor/export_tensor.cpp @@ -16,22 +16,22 @@ #include "export_tensor.hpp" #include -#include #include #include #include namespace tensorwrapper { + namespace { template -auto get_desc_() -> decltype(pybind11::format_descriptor::format()) { +auto get_desc_() -> decltype(py::format_descriptor::format()) { if constexpr(std::is_same_v) - return pybind11::format_descriptor::format(); + return py::format_descriptor::format(); else if constexpr(std::is_same_v) - return pybind11::format_descriptor::format(); + return py::format_descriptor::format(); else if constexpr(std::is_same_v) - return pybind11::format_descriptor::format(); + return py::format_descriptor::format(); else throw std::runtime_error("Unsupported floating point type!"); } @@ -44,7 +44,7 @@ struct GetBufferDataKernel { m_rank(rank), m_psmooth_shape(&smooth_shape) {} template - pybind11::buffer_info operator()(std::span buffer) { + py::buffer_info operator()(std::span buffer) { using clean_type = std::decay_t; // We have only tested with doubles at the moment. @@ -66,7 +66,7 @@ struct GetBufferDataKernel { strides[rank_i] = stride_i * nbytes; } auto* ptr = const_cast(buffer.data()); - return pybind11::buffer_info(ptr, nbytes, desc, rank, shape, strides); + return py::buffer_info(ptr, nbytes, desc, rank, shape, strides); } size_type m_rank; @@ -74,8 +74,8 @@ struct GetBufferDataKernel { }; template -Tensor make_tensor_(pybind11::buffer_info& info) { - if(info.format != pybind11::format_descriptor::format()) +Tensor make_tensor_(py::buffer_info& info) { + if(info.format != py::format_descriptor::format()) throw std::runtime_error( "Incompatible format: expected a float array!"); @@ -107,9 +107,9 @@ auto make_buffer_info(buffer::Contiguous& buffer) { return buffer::visit_contiguous_buffer(kernel, buffer); } -Tensor make_tensor(pybind11::buffer b) { - pybind11::buffer_info info = b.request(); - if(info.format == pybind11::format_descriptor::format()) +Tensor make_tensor(py::buffer b) { + py::buffer_info info = b.request(); + if(info.format == py::format_descriptor::format()) return make_tensor_(info); else throw std::runtime_error( @@ -117,12 +117,12 @@ Tensor make_tensor(pybind11::buffer b) { } void export_tensor(py_module_reference m) { - py_class_type(m, "Tensor", pybind11::buffer_protocol()) - .def(pybind11::init<>()) - .def(pybind11::init([](pybind11::buffer b) { return make_tensor(b); })) + py_class_type(m, "Tensor", py::buffer_protocol()) + .def(py::init<>()) + .def(py::init([](py::buffer b) { return make_tensor(b); })) .def("rank", &Tensor::rank) - .def(pybind11::self == pybind11::self) - .def(pybind11::self != pybind11::self) + .def(py::self == py::self) + .def(py::self != py::self) .def("__str__", [](Tensor& self) { return self.to_string(); }) .def_buffer([](Tensor& t) { auto pbuffer = dynamic_cast(&t.buffer());