Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,7 @@ execute_process(
COMMAND ${Python_EXECUTABLE}
${CMAKE_CURRENT_SOURCE_DIR}/scripts/generate_public_headers.py
--output-dir ${CMAKE_CURRENT_SOURCE_DIR}/generated/include
--source-output ${CMAKE_CURRENT_SOURCE_DIR}/generated/src/runtime_dispatch.cc
--devices ${INFINI_RT_PUBLIC_HEADER_DEVICES}
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
RESULT_VARIABLE INFINI_RT_PUBLIC_HEADER_RESULT
Expand Down
246 changes: 246 additions & 0 deletions scripts/generate_public_headers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import argparse
import dataclasses
import pathlib
import re


_DEVICE_HEADERS = {
Expand Down Expand Up @@ -45,6 +47,26 @@
),
}

_DEVICE_TYPES = {
"cpu": "Device::Type::kCpu",
"nvidia": "Device::Type::kNvidia",
"iluvatar": "Device::Type::kIluvatar",
"metax": "Device::Type::kMetax",
"moore": "Device::Type::kMoore",
"cambricon": "Device::Type::kCambricon",
"ascend": "Device::Type::kAscend",
}

_RUNTIME_HEADERS = {
"cpu": "native/cpu/runtime_.h",
"nvidia": "native/cuda/nvidia/runtime_.h",
"iluvatar": "native/cuda/iluvatar/runtime_.h",
"metax": "native/cuda/metax/runtime_.h",
"moore": "native/cuda/moore/runtime_.h",
"cambricon": "native/cambricon/runtime_.h",
"ascend": "native/ascend/runtime_.h",
}


def _guard(path):
token = "_".join(path.parts).replace(".", "_").upper()
Expand Down Expand Up @@ -92,9 +114,230 @@ def _write_generated_header(include_root, devices):
)


@dataclasses.dataclass(frozen=True)
class _Param:
type: str
name: str


@dataclasses.dataclass(frozen=True)
class _Function:
return_type: str
name: str
params: tuple[_Param, ...]

def signature(self):
return f"{self.return_type} {self.name}({self.params_decl()})"

def params_decl(self):
return ", ".join(f"{param.type} {param.name}" for param in self.params)


def _parse_param(param):
param_type, param_name = param.strip().rsplit(" ", 1)

return _Param(param_type, param_name)


def _parse_runtime_functions(runtime_header):
text = pathlib.Path(runtime_header).read_text()
return tuple(
_Function(
return_type,
name,
tuple(_parse_param(param) for param in params.split(", ") if param),
)
for return_type, name, params in re.findall(
r"^(void) ([A-Z]\w*)\(([^()]*)\);$", text, re.MULTILINE
)
)


def _abort_statement(message):
return f""" assert(false && "{message}");
std::abort();"""


def _dispatch_cases(devices, statements):
return "\n".join(
f""" case {_DEVICE_TYPES[device]}: {{
{statements.replace("__DEVICE_TYPE__", _DEVICE_TYPES[device])}
return;
}}"""
for device in devices
)


def _selector(function):
for param in function.params:
if param.type == "Device":
return f"{param.name}.type()"
if param.type == "Device::Type":
return param.name

return "current_device.type()"


def _runtime_arg(param):
if param.type == "Device":
return f"{param.name}.index()"
if param.type == "Device::Type":
return None
if param.type == "MemcpyKind":
return f"RuntimeMemcpyKind<__DEVICE_TYPE__>({param.name})"

return param.name


def _runtime_args(function):
args = (_runtime_arg(param) for param in function.params)

return ", ".join(arg for arg in args if arg is not None)


def _preconditions(function):
required_pointer_names = {
"GetDevice": {"device"},
"GetDeviceCount": {"count"},
}
checks = []
for param in function.params:
if param.type.endswith("**") or param.name in required_pointer_names.get(
function.name, set()
):
checks.append(f" assert({param.name} != nullptr);")

return "\n".join(checks)


def _post_dispatch(function):
if function.name == "SetDevice":
return "\n current_device = device;"

return ""


def _runtime_call(function):
args = _runtime_args(function)
if args:
return f"Runtime<__DEVICE_TYPE__>::{function.name}({args})"

return f"Runtime<__DEVICE_TYPE__>::{function.name}()"


def _write_get_device(function, devices):
device_param = function.params[0].name
cases = _dispatch_cases(
devices,
f""" int index = current_device.index();
CheckCall([&] {{ return Runtime<__DEVICE_TYPE__>::GetDevice(&index); }});
current_device = Device{{current_device.type(), index}};
*{device_param} = current_device;""",
)

return f"""void GetDevice(Device* {device_param}) {{
assert({device_param} != nullptr);

switch (current_device.type()) {{
{cases}
default:
{_abort_statement("runtime device is not enabled")}
}}
}}
"""


def _write_dispatch_function(function, devices):
if function.name == "GetDevice":
return _write_get_device(function, devices)

cases = _dispatch_cases(
devices,
f""" CheckCall([&] {{ return {_runtime_call(function)}; }});{_post_dispatch(function)}""",
)
preconditions = _preconditions(function)
if preconditions:
preconditions = f"{preconditions}\n\n"

return f"""{function.signature()} {{
{preconditions} switch ({_selector(function)}) {{
{cases}
default:
{_abort_statement("runtime device is not enabled")}
}}
}}
"""


def _write_runtime_dispatch(source_path, runtime_header, devices):
first_device_type = _DEVICE_TYPES[devices[0]]
includes = ['#include "runtime.h"']
includes.extend(f'#include "{_RUNTIME_HEADERS[device]}"' for device in devices)
functions = _parse_runtime_functions(runtime_header)
dispatch_functions = "\n".join(
_write_dispatch_function(function, devices) for function in functions
)

source_path.parent.mkdir(parents=True, exist_ok=True)
source_path.write_text(
f"""#include <cassert>
#include <cstdlib>
#include <type_traits>
#include <utility>

{chr(10).join(includes)}

namespace infini::rt {{
namespace {{

thread_local Device current_device{{{first_device_type}, 0}};

template <typename Func>
void CheckCall(Func&& func) {{
using ReturnType = decltype(std::forward<Func>(func)());

if constexpr (std::is_void_v<ReturnType>) {{
std::forward<Func>(func)();
}} else {{
ReturnType status = std::forward<Func>(func)();
if (status != ReturnType{{}}) {{
assert(false && "runtime call failed");
std::abort();
}}
}}
}}

template <Device::Type kDev>
auto RuntimeMemcpyKind(MemcpyKind kind) {{
switch (kind) {{
case MemcpyKind::kHostToHost:
return Runtime<kDev>::MemcpyHostToHost;
case MemcpyKind::kHostToDevice:
return Runtime<kDev>::MemcpyHostToDevice;
case MemcpyKind::kDeviceToHost:
return Runtime<kDev>::MemcpyDeviceToHost;
case MemcpyKind::kDeviceToDevice:
return Runtime<kDev>::MemcpyDeviceToDevice;
}}

assert(false && "unsupported memcpy kind");
std::abort();
return Runtime<kDev>::MemcpyHostToHost;
}}

}} // namespace

{dispatch_functions}
}} // namespace infini::rt
"""
)


def main():
parser = argparse.ArgumentParser()
parser.add_argument("--output-dir", default="generated/include")
parser.add_argument("--source-output", default="generated/src/runtime_dispatch.cc")
parser.add_argument("--runtime-header", default="src/runtime.h")
parser.add_argument("--devices", nargs="+", required=True)
args = parser.parse_args()

Expand All @@ -112,6 +355,9 @@ def main():
_write_wrapper(include_root, wrapper_device, header_name, target)

_write_generated_header(include_root, devices)
_write_runtime_dispatch(
pathlib.Path(args.source_output), args.runtime_header, devices
)


if __name__ == "__main__":
Expand Down
4 changes: 3 additions & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@ add_library(infinirt SHARED)
include(GNUInstallDirs)

file(GLOB BASE_SRCS CONFIGURE_DEPENDS "*.cc")
target_sources(infinirt PRIVATE ${BASE_SRCS})
target_sources(infinirt PRIVATE
${BASE_SRCS}
${PROJECT_SOURCE_DIR}/generated/src/runtime_dispatch.cc)

if(WITH_CPU)
target_compile_definitions(infinirt PUBLIC WITH_CPU=1)
Expand Down
21 changes: 21 additions & 0 deletions src/native/ascend/runtime_.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
#ifndef INFINI_RT_ASCEND_RUNTIME__H_
#define INFINI_RT_ASCEND_RUNTIME__H_

#include <cassert>
#include <cstdint>

// clang-format off
#include "acl/acl.h"
// clang-format on
Expand All @@ -17,6 +20,20 @@ struct Runtime<Device::Type::kAscend>

static constexpr Device::Type kDeviceType = Device::Type::kAscend;

static constexpr auto SetDevice = aclrtSetDevice;

static constexpr auto GetDevice = aclrtGetDevice;

static auto GetDeviceCount(int* count) {
Comment thread
voltjia marked this conversation as resolved.
assert(count != nullptr);
std::uint32_t device_count = 0;
auto status = aclrtGetDeviceCount(&device_count);
*count = static_cast<int>(device_count);
return status;
}

static constexpr auto DeviceSynchronize = aclrtSynchronizeDevice;

static constexpr auto Malloc = [](void** ptr, size_t size) {
return aclrtMalloc(ptr, size, ACL_MEM_MALLOC_HUGE_FIRST);
};
Expand All @@ -28,10 +45,14 @@ struct Runtime<Device::Type::kAscend>
return aclrtMemcpy(dst, count, src, count, kind);
};

static constexpr auto MemcpyHostToHost = ACL_MEMCPY_HOST_TO_HOST;

static constexpr auto MemcpyHostToDevice = ACL_MEMCPY_HOST_TO_DEVICE;

static constexpr auto MemcpyDeviceToHost = ACL_MEMCPY_DEVICE_TO_HOST;

static constexpr auto MemcpyDeviceToDevice = ACL_MEMCPY_DEVICE_TO_DEVICE;

static constexpr auto Memset = [](void* ptr, int value, size_t count) {
return aclrtMemset(ptr, count, value, count);
};
Expand Down
26 changes: 25 additions & 1 deletion src/native/cambricon/runtime_.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

#include <cnrt.h>

#include <cassert>
#include <cstddef>

#include "native/cambricon/device_.h"
#include "runtime.h"

Expand All @@ -15,16 +18,37 @@ struct Runtime<Device::Type::kCambricon>

static constexpr Device::Type kDeviceType = Device::Type::kCambricon;

static constexpr auto SetDevice = cnrtSetDevice;

static constexpr auto GetDevice = cnrtGetDevice;

static auto GetDeviceCount(int* count) {
Comment thread
voltjia marked this conversation as resolved.
assert(count != nullptr);
unsigned int device_count = 0;
auto status = cnrtGetDeviceCount(&device_count);
*count = static_cast<int>(device_count);
return status;
}

static constexpr auto DeviceSynchronize = cnrtSyncDevice;

static constexpr auto Malloc = cnrtMalloc;

static constexpr auto Free = cnrtFree;

static constexpr auto Memcpy = cnrtMemcpy;
static constexpr auto Memcpy = [](void* dst, const void* src,
Comment thread
voltjia marked this conversation as resolved.
std::size_t size, auto kind) {
return cnrtMemcpy(dst, const_cast<void*>(src), size, kind);
};

static constexpr auto MemcpyHostToHost = cnrtMemcpyHostToHost;

static constexpr auto MemcpyHostToDevice = cnrtMemcpyHostToDev;

static constexpr auto MemcpyDeviceToHost = cnrtMemcpyDevToHost;

static constexpr auto MemcpyDeviceToDevice = cnrtMemcpyDevToDev;

static constexpr auto Memset = cnrtMemset;
};

Expand Down
Loading
Loading