diff --git a/CMakeLists.txt b/CMakeLists.txt index 583949c..38b8d72 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -321,6 +321,7 @@ execute_process( COMMAND ${Python_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/scripts/generate_public_headers.py --output-dir ${CMAKE_CURRENT_SOURCE_DIR}/generated/include + --source-output ${CMAKE_CURRENT_SOURCE_DIR}/generated/src/runtime_dispatch.cc --devices ${INFINI_RT_PUBLIC_HEADER_DEVICES} WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} RESULT_VARIABLE INFINI_RT_PUBLIC_HEADER_RESULT diff --git a/scripts/generate_public_headers.py b/scripts/generate_public_headers.py index dc49df8..71796c6 100644 --- a/scripts/generate_public_headers.py +++ b/scripts/generate_public_headers.py @@ -1,5 +1,7 @@ import argparse +import dataclasses import pathlib +import re _DEVICE_HEADERS = { @@ -45,6 +47,26 @@ ), } +_DEVICE_TYPES = { + "cpu": "Device::Type::kCpu", + "nvidia": "Device::Type::kNvidia", + "iluvatar": "Device::Type::kIluvatar", + "metax": "Device::Type::kMetax", + "moore": "Device::Type::kMoore", + "cambricon": "Device::Type::kCambricon", + "ascend": "Device::Type::kAscend", +} + +_RUNTIME_HEADERS = { + "cpu": "native/cpu/runtime_.h", + "nvidia": "native/cuda/nvidia/runtime_.h", + "iluvatar": "native/cuda/iluvatar/runtime_.h", + "metax": "native/cuda/metax/runtime_.h", + "moore": "native/cuda/moore/runtime_.h", + "cambricon": "native/cambricon/runtime_.h", + "ascend": "native/ascend/runtime_.h", +} + def _guard(path): token = "_".join(path.parts).replace(".", "_").upper() @@ -92,9 +114,230 @@ def _write_generated_header(include_root, devices): ) +@dataclasses.dataclass(frozen=True) +class _Param: + type: str + name: str + + +@dataclasses.dataclass(frozen=True) +class _Function: + return_type: str + name: str + params: tuple[_Param, ...] 
+ + def signature(self): + return f"{self.return_type} {self.name}({self.params_decl()})" + + def params_decl(self): + return ", ".join(f"{param.type} {param.name}" for param in self.params) + + +def _parse_param(param): + param_type, param_name = param.strip().rsplit(" ", 1) + + return _Param(param_type, param_name) + + +def _parse_runtime_functions(runtime_header): + text = pathlib.Path(runtime_header).read_text() + return tuple( + _Function( + return_type, + name, + tuple(_parse_param(param) for param in params.split(", ") if param), + ) + for return_type, name, params in re.findall( + r"^(void) ([A-Z]\w*)\(([^()]*)\);$", text, re.MULTILINE + ) + ) + + +def _abort_statement(message): + return f""" assert(false && "{message}"); + std::abort();""" + + +def _dispatch_cases(devices, statements): + return "\n".join( + f""" case {_DEVICE_TYPES[device]}: {{ +{statements.replace("__DEVICE_TYPE__", _DEVICE_TYPES[device])} + return; + }}""" + for device in devices + ) + + +def _selector(function): + for param in function.params: + if param.type == "Device": + return f"{param.name}.type()" + if param.type == "Device::Type": + return param.name + + return "current_device.type()" + + +def _runtime_arg(param): + if param.type == "Device": + return f"{param.name}.index()" + if param.type == "Device::Type": + return None + if param.type == "MemcpyKind": + return f"RuntimeMemcpyKind<__DEVICE_TYPE__>({param.name})" + + return param.name + + +def _runtime_args(function): + args = (_runtime_arg(param) for param in function.params) + + return ", ".join(arg for arg in args if arg is not None) + + +def _preconditions(function): + required_pointer_names = { + "GetDevice": {"device"}, + "GetDeviceCount": {"count"}, + } + checks = [] + for param in function.params: + if param.type.endswith("**") or param.name in required_pointer_names.get( + function.name, set() + ): + checks.append(f" assert({param.name} != nullptr);") + + return "\n".join(checks) + + +def _post_dispatch(function): + if 
function.name == "SetDevice": + return "\n current_device = device;" + + return "" + + +def _runtime_call(function): + args = _runtime_args(function) + if args: + return f"Runtime<__DEVICE_TYPE__>::{function.name}({args})" + + return f"Runtime<__DEVICE_TYPE__>::{function.name}()" + + +def _write_get_device(function, devices): + device_param = function.params[0].name + cases = _dispatch_cases( + devices, + f""" int index = current_device.index(); + CheckCall([&] {{ return Runtime<__DEVICE_TYPE__>::GetDevice(&index); }}); + current_device = Device{{current_device.type(), index}}; + *{device_param} = current_device;""", + ) + + return f"""void GetDevice(Device* {device_param}) {{ + assert({device_param} != nullptr); + + switch (current_device.type()) {{ +{cases} + default: +{_abort_statement("runtime device is not enabled")} + }} +}} +""" + + +def _write_dispatch_function(function, devices): + if function.name == "GetDevice": + return _write_get_device(function, devices) + + cases = _dispatch_cases( + devices, + f""" CheckCall([&] {{ return {_runtime_call(function)}; }});{_post_dispatch(function)}""", + ) + preconditions = _preconditions(function) + if preconditions: + preconditions = f"{preconditions}\n\n" + + return f"""{function.signature()} {{ +{preconditions} switch ({_selector(function)}) {{ +{cases} + default: +{_abort_statement("runtime device is not enabled")} + }} +}} +""" + + +def _write_runtime_dispatch(source_path, runtime_header, devices): + first_device_type = _DEVICE_TYPES[devices[0]] + includes = ['#include "runtime.h"'] + includes.extend(f'#include "{_RUNTIME_HEADERS[device]}"' for device in devices) + functions = _parse_runtime_functions(runtime_header) + dispatch_functions = "\n".join( + _write_dispatch_function(function, devices) for function in functions + ) + + source_path.parent.mkdir(parents=True, exist_ok=True) + source_path.write_text( + f"""#include +#include +#include +#include + +{chr(10).join(includes)} + +namespace infini::rt {{ 
+namespace {{ + +thread_local Device current_device{{{first_device_type}, 0}}; + +template +void CheckCall(Func&& func) {{ + using ReturnType = decltype(std::forward(func)()); + + if constexpr (std::is_void_v) {{ + std::forward(func)(); + }} else {{ + ReturnType status = std::forward(func)(); + if (status != ReturnType{{}}) {{ + assert(false && "runtime call failed"); + std::abort(); + }} + }} +}} + +template +auto RuntimeMemcpyKind(MemcpyKind kind) {{ + switch (kind) {{ + case MemcpyKind::kHostToHost: + return Runtime::MemcpyHostToHost; + case MemcpyKind::kHostToDevice: + return Runtime::MemcpyHostToDevice; + case MemcpyKind::kDeviceToHost: + return Runtime::MemcpyDeviceToHost; + case MemcpyKind::kDeviceToDevice: + return Runtime::MemcpyDeviceToDevice; + }} + + assert(false && "unsupported memcpy kind"); + std::abort(); + return Runtime::MemcpyHostToHost; +}} + +}} // namespace + +{dispatch_functions} +}} // namespace infini::rt +""" + ) + + def main(): parser = argparse.ArgumentParser() parser.add_argument("--output-dir", default="generated/include") + parser.add_argument("--source-output", default="generated/src/runtime_dispatch.cc") + parser.add_argument("--runtime-header", default="src/runtime.h") parser.add_argument("--devices", nargs="+", required=True) args = parser.parse_args() @@ -112,6 +355,9 @@ def main(): _write_wrapper(include_root, wrapper_device, header_name, target) _write_generated_header(include_root, devices) + _write_runtime_dispatch( + pathlib.Path(args.source_output), args.runtime_header, devices + ) if __name__ == "__main__": diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index f1be901..5b1e1c0 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -3,7 +3,9 @@ add_library(infinirt SHARED) include(GNUInstallDirs) file(GLOB BASE_SRCS CONFIGURE_DEPENDS "*.cc") -target_sources(infinirt PRIVATE ${BASE_SRCS}) +target_sources(infinirt PRIVATE + ${BASE_SRCS} + ${PROJECT_SOURCE_DIR}/generated/src/runtime_dispatch.cc) if(WITH_CPU) 
target_compile_definitions(infinirt PUBLIC WITH_CPU=1) diff --git a/src/native/ascend/runtime_.h b/src/native/ascend/runtime_.h index dce9533..8b33e54 100644 --- a/src/native/ascend/runtime_.h +++ b/src/native/ascend/runtime_.h @@ -1,6 +1,9 @@ #ifndef INFINI_RT_ASCEND_RUNTIME__H_ #define INFINI_RT_ASCEND_RUNTIME__H_ +#include +#include + // clang-format off #include "acl/acl.h" // clang-format on @@ -17,6 +20,20 @@ struct Runtime static constexpr Device::Type kDeviceType = Device::Type::kAscend; + static constexpr auto SetDevice = aclrtSetDevice; + + static constexpr auto GetDevice = aclrtGetDevice; + + static auto GetDeviceCount(int* count) { + assert(count != nullptr); + std::uint32_t device_count = 0; + auto status = aclrtGetDeviceCount(&device_count); + *count = static_cast(device_count); + return status; + } + + static constexpr auto DeviceSynchronize = aclrtSynchronizeDevice; + static constexpr auto Malloc = [](void** ptr, size_t size) { return aclrtMalloc(ptr, size, ACL_MEM_MALLOC_HUGE_FIRST); }; @@ -28,10 +45,14 @@ struct Runtime return aclrtMemcpy(dst, count, src, count, kind); }; + static constexpr auto MemcpyHostToHost = ACL_MEMCPY_HOST_TO_HOST; + static constexpr auto MemcpyHostToDevice = ACL_MEMCPY_HOST_TO_DEVICE; static constexpr auto MemcpyDeviceToHost = ACL_MEMCPY_DEVICE_TO_HOST; + static constexpr auto MemcpyDeviceToDevice = ACL_MEMCPY_DEVICE_TO_DEVICE; + static constexpr auto Memset = [](void* ptr, int value, size_t count) { return aclrtMemset(ptr, count, value, count); }; diff --git a/src/native/cambricon/runtime_.h b/src/native/cambricon/runtime_.h index 2b686f2..4db4920 100644 --- a/src/native/cambricon/runtime_.h +++ b/src/native/cambricon/runtime_.h @@ -3,6 +3,9 @@ #include +#include +#include + #include "native/cambricon/device_.h" #include "runtime.h" @@ -15,16 +18,37 @@ struct Runtime static constexpr Device::Type kDeviceType = Device::Type::kCambricon; + static constexpr auto SetDevice = cnrtSetDevice; + + static constexpr auto GetDevice 
= cnrtGetDevice; + + static auto GetDeviceCount(int* count) { + assert(count != nullptr); + unsigned int device_count = 0; + auto status = cnrtGetDeviceCount(&device_count); + *count = static_cast(device_count); + return status; + } + + static constexpr auto DeviceSynchronize = cnrtSyncDevice; + static constexpr auto Malloc = cnrtMalloc; static constexpr auto Free = cnrtFree; - static constexpr auto Memcpy = cnrtMemcpy; + static constexpr auto Memcpy = [](void* dst, const void* src, + std::size_t size, auto kind) { + return cnrtMemcpy(dst, const_cast(src), size, kind); + }; + + static constexpr auto MemcpyHostToHost = cnrtMemcpyHostToHost; static constexpr auto MemcpyHostToDevice = cnrtMemcpyHostToDev; static constexpr auto MemcpyDeviceToHost = cnrtMemcpyDevToHost; + static constexpr auto MemcpyDeviceToDevice = cnrtMemcpyDevToDev; + static constexpr auto Memset = cnrtMemset; }; diff --git a/src/native/cpu/runtime_.h b/src/native/cpu/runtime_.h index 29219b8..bf5a81c 100644 --- a/src/native/cpu/runtime_.h +++ b/src/native/cpu/runtime_.h @@ -1,6 +1,7 @@ #ifndef INFINI_RT_CPU_RUNTIME__H_ #define INFINI_RT_CPU_RUNTIME__H_ +#include #include #include @@ -12,6 +13,25 @@ template <> struct Runtime : RuntimeBase> { static constexpr Device::Type kDeviceType = Device::Type::kCpu; + static void SetDevice(int index) { + if (index != 0) { + assert(false && "CPU device index must be 0"); + std::abort(); + } + } + + static void GetDevice(int* index) { + assert(index != nullptr); + *index = 0; + } + + static void GetDeviceCount(int* count) { + assert(count != nullptr); + *count = 1; + } + + static void DeviceSynchronize() {} + static void Malloc(void** ptr, std::size_t size) { *ptr = std::malloc(size); } static void Free(void* ptr) { std::free(ptr); } @@ -20,11 +40,17 @@ struct Runtime : RuntimeBase> { std::memcpy(dst, src, size); } - static constexpr auto Memset = std::memset; + static void Memset(void* ptr, int value, std::size_t count) { + std::memset(ptr, value, count); + } + 
+ static constexpr int MemcpyHostToHost = 0; static constexpr int MemcpyHostToDevice = 0; static constexpr int MemcpyDeviceToHost = 1; + + static constexpr int MemcpyDeviceToDevice = 0; }; static_assert(Runtime::Validate()); diff --git a/src/native/cuda/iluvatar/runtime_.h b/src/native/cuda/iluvatar/runtime_.h index eb79a33..8a1b649 100644 --- a/src/native/cuda/iluvatar/runtime_.h +++ b/src/native/cuda/iluvatar/runtime_.h @@ -19,6 +19,14 @@ struct Runtime static constexpr Device::Type kDeviceType = Device::Type::kIluvatar; + static constexpr auto SetDevice = cudaSetDevice; + + static constexpr auto GetDevice = cudaGetDevice; + + static constexpr auto GetDeviceCount = cudaGetDeviceCount; + + static constexpr auto DeviceSynchronize = cudaDeviceSynchronize; + static constexpr auto Malloc = [](auto&&... args) { return cudaMalloc(std::forward(args)...); }; @@ -27,10 +35,14 @@ struct Runtime static constexpr auto Free = cudaFree; + static constexpr auto MemcpyHostToHost = cudaMemcpyHostToHost; + static constexpr auto MemcpyHostToDevice = cudaMemcpyHostToDevice; static constexpr auto MemcpyDeviceToHost = cudaMemcpyDeviceToHost; + static constexpr auto MemcpyDeviceToDevice = cudaMemcpyDeviceToDevice; + static constexpr auto Memset = cudaMemset; }; diff --git a/src/native/cuda/metax/runtime_.h b/src/native/cuda/metax/runtime_.h index a436e21..348d44b 100644 --- a/src/native/cuda/metax/runtime_.h +++ b/src/native/cuda/metax/runtime_.h @@ -15,16 +15,28 @@ struct Runtime static constexpr Device::Type kDeviceType = Device::Type::kMetax; + static constexpr auto SetDevice = mcSetDevice; + + static constexpr auto GetDevice = mcGetDevice; + + static constexpr auto GetDeviceCount = mcGetDeviceCount; + + static constexpr auto DeviceSynchronize = mcDeviceSynchronize; + static constexpr auto Malloc = mcMalloc; static constexpr auto Memcpy = mcMemcpy; static constexpr auto Free = mcFree; + static constexpr auto MemcpyHostToHost = mcMemcpyHostToHost; + static constexpr auto 
MemcpyHostToDevice = mcMemcpyHostToDevice; static constexpr auto MemcpyDeviceToHost = mcMemcpyDeviceToHost; + static constexpr auto MemcpyDeviceToDevice = mcMemcpyDeviceToDevice; + static constexpr auto Memset = mcMemset; }; diff --git a/src/native/cuda/moore/runtime_.h b/src/native/cuda/moore/runtime_.h index fd02b0d..42227e8 100644 --- a/src/native/cuda/moore/runtime_.h +++ b/src/native/cuda/moore/runtime_.h @@ -17,6 +17,14 @@ struct Runtime static constexpr Device::Type kDeviceType = Device::Type::kMoore; + static constexpr auto SetDevice = musaSetDevice; + + static constexpr auto GetDevice = musaGetDevice; + + static constexpr auto GetDeviceCount = musaGetDeviceCount; + + static constexpr auto DeviceSynchronize = musaDeviceSynchronize; + static constexpr auto Malloc = [](auto&&... args) { return musaMalloc(std::forward(args)...); }; @@ -29,10 +37,14 @@ struct Runtime return musaFree(std::forward(args)...); }; + static constexpr auto MemcpyHostToHost = musaMemcpyHostToHost; + static constexpr auto MemcpyHostToDevice = musaMemcpyHostToDevice; static constexpr auto MemcpyDeviceToHost = musaMemcpyDeviceToHost; + static constexpr auto MemcpyDeviceToDevice = musaMemcpyDeviceToDevice; + static constexpr auto Memset = musaMemset; }; diff --git a/src/native/cuda/nvidia/runtime_.h b/src/native/cuda/nvidia/runtime_.h index c81f5fa..f6a9f2d 100644 --- a/src/native/cuda/nvidia/runtime_.h +++ b/src/native/cuda/nvidia/runtime_.h @@ -19,6 +19,14 @@ struct Runtime static constexpr Device::Type kDeviceType = Device::Type::kNvidia; + static constexpr auto SetDevice = cudaSetDevice; + + static constexpr auto GetDevice = cudaGetDevice; + + static constexpr auto GetDeviceCount = cudaGetDeviceCount; + + static constexpr auto DeviceSynchronize = cudaDeviceSynchronize; + static constexpr auto Malloc = [](auto&&... 
args) { return cudaMalloc(std::forward(args)...); }; @@ -27,10 +35,14 @@ struct Runtime static constexpr auto Free = cudaFree; + static constexpr auto MemcpyHostToHost = cudaMemcpyHostToHost; + static constexpr auto MemcpyHostToDevice = cudaMemcpyHostToDevice; static constexpr auto MemcpyDeviceToHost = cudaMemcpyDeviceToHost; + static constexpr auto MemcpyDeviceToDevice = cudaMemcpyDeviceToDevice; + static constexpr auto Memset = cudaMemset; }; diff --git a/src/runtime.h b/src/runtime.h index 839477b..ebc2698 100644 --- a/src/runtime.h +++ b/src/runtime.h @@ -1,6 +1,7 @@ #ifndef INFINI_RT_RUNTIME_H_ #define INFINI_RT_RUNTIME_H_ +#include #include #include "device.h" @@ -50,6 +51,29 @@ struct DeviceRuntime : RuntimeBase { } }; +enum class MemcpyKind { + kHostToHost, + kHostToDevice, + kDeviceToHost, + kDeviceToDevice, +}; + +void SetDevice(Device device); + +void GetDevice(Device* device); + +void GetDeviceCount(int* count, Device::Type type); + +void DeviceSynchronize(); + +void Malloc(void** ptr, std::size_t size); + +void Free(void* ptr); + +void Memset(void* ptr, int value, std::size_t count); + +void Memcpy(void* dst, const void* src, std::size_t count, MemcpyKind kind); + } // namespace infini::rt #endif diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index d9ce959..ab54530 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -11,6 +11,10 @@ if(WITH_CPU) add_infini_rt_test(test_cpu_runtime test_cpu_runtime.cc) endif() +if(WITH_CPU OR WITH_NVIDIA) + add_infini_rt_test(test_runtime_dispatch test_runtime_dispatch.cc) +endif() + if(WITH_NVIDIA) add_infini_rt_test(test_nvidia_runtime test_nvidia_runtime.cc) endif() @@ -20,6 +24,13 @@ set(INFINI_RT_TEST_INSTALL_PREFIX set(INFINI_RT_TEST_CONSUMER_BINARY "${CMAKE_CURRENT_BINARY_DIR}/install_consumer_smoke") set(INFINI_RT_TEST_EXTRA_LIBRARY_DIRS "") +set(INFINI_RT_TEST_CONSUMER_BACKEND NONE) + +if(WITH_CPU) + set(INFINI_RT_TEST_CONSUMER_BACKEND CPU) +elseif(WITH_NVIDIA) + 
set(INFINI_RT_TEST_CONSUMER_BACKEND NVIDIA) +endif() if(WITH_ASCEND) get_filename_component( @@ -48,6 +59,7 @@ add_test( "-DINFINI_RT_CONSUMER_BINARY=${INFINI_RT_TEST_CONSUMER_BINARY}" "-DINFINI_RT_CXX_COMPILER=${CMAKE_CXX_COMPILER}" "-DINFINI_RT_EXTRA_LIBRARY_PATHS=${INFINI_RT_TEST_EXTRA_LIBRARY_PATHS}" + "-DINFINI_RT_CONSUMER_BACKEND=${INFINI_RT_TEST_CONSUMER_BACKEND}" -P "${CMAKE_CURRENT_SOURCE_DIR}/compile_install_consumer.cmake") set_tests_properties(test_install_consumer PROPERTIES FIXTURES_REQUIRED infini_rt_install) diff --git a/tests/compile_install_consumer.cmake b/tests/compile_install_consumer.cmake index bccbde7..169311e 100644 --- a/tests/compile_install_consumer.cmake +++ b/tests/compile_install_consumer.cmake @@ -14,8 +14,14 @@ if(NOT EXISTS "${INFINI_RT_LIBRARY_DIR}/libinfinirt.so") endif() set(INFINI_RT_EXTRA_LINK_ARGS "") +set(INFINI_RT_EXTRA_COMPILE_ARGS "") set(INFINI_RT_LD_LIBRARY_PATH "${INFINI_RT_LIBRARY_DIR}") +if(INFINI_RT_CONSUMER_BACKEND AND NOT INFINI_RT_CONSUMER_BACKEND STREQUAL "NONE") + list(APPEND INFINI_RT_EXTRA_COMPILE_ARGS + "-DINFINI_RT_CONSUMER_BACKEND_${INFINI_RT_CONSUMER_BACKEND}=1") +endif() + if(INFINI_RT_EXTRA_LIBRARY_PATHS) string(REPLACE ":" ";" INFINI_RT_EXTRA_LIBRARY_DIRS "${INFINI_RT_EXTRA_LIBRARY_PATHS}") @@ -34,6 +40,7 @@ execute_process( -std=c++17 -Werror "-I${INFINI_RT_INCLUDE_DIR}" + ${INFINI_RT_EXTRA_COMPILE_ARGS} "${INFINI_RT_CONSUMER_SOURCE}" "-L${INFINI_RT_LIBRARY_DIR}" -linfinirt diff --git a/tests/install_consumer_smoke.cc b/tests/install_consumer_smoke.cc index 19966f6..97442b5 100644 --- a/tests/install_consumer_smoke.cc +++ b/tests/install_consumer_smoke.cc @@ -1,6 +1,7 @@ #include #include +#include #include int main() { @@ -17,5 +18,22 @@ int main() { return 1; } +#if defined(INFINI_RT_CONSUMER_BACKEND_CPU) || \ + defined(INFINI_RT_CONSUMER_BACKEND_NVIDIA) +#if defined(INFINI_RT_CONSUMER_BACKEND_NVIDIA) + const infini::rt::Device runtime_device{infini::rt::Device::Type::kNvidia}; +#else + const 
infini::rt::Device runtime_device{infini::rt::Device::Type::kCpu}; +#endif + + void* ptr = nullptr; + infini::rt::SetDevice(runtime_device); + infini::rt::Malloc(&ptr, sizeof(std::uint32_t)); + if (ptr == nullptr) { + return 1; + } + infini::rt::Free(ptr); +#endif + return 0; } diff --git a/tests/test_runtime_dispatch.cc b/tests/test_runtime_dispatch.cc new file mode 100644 index 0000000..9be92a7 --- /dev/null +++ b/tests/test_runtime_dispatch.cc @@ -0,0 +1,65 @@ +#include + +#include +#include +#include + +#include "test_helper.h" + +namespace { + +infini::rt::Device RuntimeTestDevice() { +#if defined(WITH_NVIDIA) + return infini::rt::Device{infini::rt::Device::Type::kNvidia}; +#else + return infini::rt::Device{infini::rt::Device::Type::kCpu}; +#endif +} + +} // namespace + +int main() { + infini::rt::test::TestContext context; + const infini::rt::Device device = RuntimeTestDevice(); + std::array input{1, 2, 3, 4}; + std::array output{}; + void* ptr = nullptr; + + infini::rt::SetDevice(device); + + infini::rt::Device current_device; + infini::rt::GetDevice(&current_device); + context.ExpectEqual(current_device, device, + "Runtime dispatch should keep the current device."); + + int device_count = 0; + infini::rt::GetDeviceCount(&device_count, device.type()); + context.Expect(device_count > 0, + "Runtime dispatch should report at least one device."); + + infini::rt::Malloc(&ptr, input.size()); + context.Expect(ptr != nullptr, "Runtime dispatch should allocate memory."); + if (ptr == nullptr) { + return context.ExitCode(); + } + + infini::rt::Memcpy(ptr, input.data(), input.size(), + infini::rt::MemcpyKind::kHostToDevice); + infini::rt::Memcpy(output.data(), ptr, output.size(), + infini::rt::MemcpyKind::kDeviceToHost); + context.ExpectEqual(output, input, + "Runtime dispatch should copy data through memory."); + + infini::rt::Memset(ptr, 0x5A, output.size()); + infini::rt::Memcpy(output.data(), ptr, output.size(), + infini::rt::MemcpyKind::kDeviceToHost); + for (const auto 
value : output) { + context.ExpectEqual(value, static_cast(0x5A), + "Runtime dispatch should fill memory."); + } + + infini::rt::DeviceSynchronize(); + infini::rt::Free(ptr); + + return context.ExitCode(); +}