diff --git a/include/infinicore/ops/broadcast_to.hpp b/include/infinicore/ops/broadcast_to.hpp index 4da4fa788..94d98e065 100644 --- a/include/infinicore/ops/broadcast_to.hpp +++ b/include/infinicore/ops/broadcast_to.hpp @@ -1,17 +1,13 @@ #pragma once #include "../device.hpp" +#include "../graph/graph.hpp" #include "common/op.hpp" -#include namespace infinicore::op { -class BroadcastTo { -public: - // Schema: Output(y), Input(x) - using schema = void (*)(Tensor, Tensor); - static void execute(Tensor y, Tensor x); - static common::OpDispatcher &dispatcher(); -}; + +INFINICORE_GRAPH_OP_CLASS(BroadcastTo, Tensor, Tensor); + Tensor broadcast_to(Tensor x, const std::vector &shape); void broadcast_to_(Tensor y, Tensor x); diff --git a/src/infinicore/ops/broadcast_to/broadcast_to.cc b/src/infinicore/ops/broadcast_to/broadcast_to.cc index 1dd5970fb..b6f3079f5 100644 --- a/src/infinicore/ops/broadcast_to/broadcast_to.cc +++ b/src/infinicore/ops/broadcast_to/broadcast_to.cc @@ -3,20 +3,19 @@ namespace infinicore::op { -common::OpDispatcher &BroadcastTo::dispatcher() { - static common::OpDispatcher dispatcher_; - return dispatcher_; -}; +INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(BroadcastTo); -void BroadcastTo::execute(Tensor y, Tensor x) { +BroadcastTo::BroadcastTo(Tensor y, Tensor x) { INFINICORE_ASSERT_TENSORS_SAME_DEVICE(y, x); - infinicore::context::setDevice(y->device()); - dispatcher().lookup(y->device().getType())(y, x); + INFINICORE_GRAPH_OP_DISPATCH(y->device().getType(), y, x); +} + +void BroadcastTo::execute(Tensor y, Tensor x) { + INFINICORE_GRAPH_OP_RECORD_OR_RUN(BroadcastTo, y, x); } Tensor broadcast_to(Tensor x, const std::vector &shape) { Shape target_shape(shape.begin(), shape.end()); - auto y = Tensor::empty(target_shape, x->dtype(), x->device()); broadcast_to_(y, x); return y; diff --git a/src/infinicore/ops/broadcast_to/broadcast_to_infiniop.cc b/src/infinicore/ops/broadcast_to/broadcast_to_infiniop.cc index 72fb588de..57c1ee901 100644 --- a/src/infinicore/ops/broadcast_to/broadcast_to_infiniop.cc +++ b/src/infinicore/ops/broadcast_to/broadcast_to_infiniop.cc @@ -1,61 +1,49 @@ -#include "../../utils.hpp" -#include "infinicore/common/hash.hpp" +#include "../infiniop_impl.hpp" #include "infinicore/ops/broadcast_to.hpp" -#include "infinicore/ops/common/cache.hpp" -#include namespace infinicore::op::broadcast_to_impl::infiniop { -// 定义描述符缓存 -thread_local common::OpCache caches( - 100, // capacity - [](infiniopBroadcastToDescriptor_t &desc) { - if (desc != nullptr) { - INFINICORE_CHECK_ERROR(infiniopDestroyBroadcastToDescriptor(desc)); - desc = nullptr; - } - }); - -void calculate(Tensor y, Tensor x) { - size_t seed = hash_combine(y, x); +INFINIOP_CACHABLE_DESCRIPTOR(Descriptor, BroadcastTo, 100); + +struct PlannedMeta { + std::shared_ptr descriptor; + graph::GraphTensor workspace, y, x; +}; - auto device = context::getDevice(); - auto &cache = caches.getCache(device); +void *plan(Tensor y, Tensor x) { + size_t seed = hash_combine(y, x); - auto desc_opt = cache.get(seed); - infiniopBroadcastToDescriptor_t desc = nullptr; + INFINIOP_CACHABLE_DESCRIPTOR_GET_OR_CREATE( + Descriptor, descriptor, BroadcastTo, + seed, + y->desc(), x->desc()); - if (!desc_opt) { - // 2. 创建描述符 - INFINICORE_CHECK_ERROR(infiniopCreateBroadcastToDescriptor( - context::getInfiniopHandle(device), - &desc, - y->desc(), - x->desc())); + INFINIOP_WORKSPACE_TENSOR(workspace, BroadcastTo, descriptor); - cache.put(seed, desc); - } else { - desc = *desc_opt; - } + return new PlannedMeta{ + descriptor, + graph::GraphTensor(workspace), + graph::GraphTensor(y), + graph::GraphTensor(x)}; +} - // 3. 获取 Workspace 并执行 - size_t workspace_size = 0; - INFINICORE_CHECK_ERROR(infiniopGetBroadcastToWorkspaceSize(desc, &workspace_size)); - std::shared_ptr workspace = context::allocateMemory(workspace_size); +void run(void *planned_meta) { + auto planned = reinterpret_cast(planned_meta); INFINICORE_CHECK_ERROR(infiniopBroadcastTo( - desc, - workspace->data(), - workspace_size, - y->data(), - x->data(), + planned->descriptor->desc, + planned->workspace->data(), + planned->workspace->numel(), + planned->y->data(), + planned->x->data(), context::getStream())); } -// 4. 注册算子实现 -static bool registered = []() { - BroadcastTo::dispatcher().registerAll(&calculate, false); - return true; -}(); +void cleanup(void **planned_meta_ptr) { + delete *reinterpret_cast(planned_meta_ptr); + *planned_meta_ptr = nullptr; +} + +INFINICORE_GRAPH_OP_REGISTER_ALLDEVICE(BroadcastTo, &plan, &run, &cleanup); } // namespace infinicore::op::broadcast_to_impl::infiniop