Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions include/infinicore/adaptor/aten_adaptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

#include <ATen/ATen.h>

#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API)
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
Expand All @@ -30,20 +30,20 @@ inline at::ScalarType to_at_dtype(DataType dtype) {
}

inline at::Device to_at_device(const Device &device) {
    // Map an InfiniCore device to the closest ATen device.
    // PyTorch ATen only exposes standard device types (e.g. kCPU/kCUDA), so
    // MetaX/QY devices are reported as CUDA devices for tensor interoperability.
    // (This span previously interleaved pre-/post-diff lines with a duplicated
    // `if`; resolved here to the post-change version.)
    if (device.getType() == Device::Type::NVIDIA || device.getType() == Device::Type::METAX || device.getType() == Device::Type::QY) {
        return at::Device(at::kCUDA, device.getIndex());
    } else if (device.getType() == Device::Type::CPU) {
        return at::Device(at::kCPU);
    } else {
        // Unknown backend: fail loudly rather than silently picking a device.
        throw std::runtime_error("Unsupported device type for ATen");
    }
}

at::Tensor to_aten_tensor(const infinicore::Tensor &t);

#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API)
c10::cuda::CUDAStream get_cuda_stream();
#endif
} // namespace infinicore::adaptor
Expand Down
23 changes: 21 additions & 2 deletions include/infinicore/adaptor/flash_attention_adaptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@
#pragma once
#include "aten_adaptor.hpp"

// NVIDIA flash-attn-nvidia.so uses namespace flash. The pip/MetaX flash_attn_2_cuda extension
// exports the same entry points at global scope (no namespace), matching FLASH_NAMESPACE builds
// where the namespace is empty.
#if !defined(ENABLE_METAX_API)
namespace flash {
#endif
std::vector<at::Tensor>
mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
Expand Down Expand Up @@ -39,7 +44,13 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_hea
int window_size_right,
const float softcap,
const bool return_softmax,
std::optional<at::Generator> gen_);
std::optional<at::Generator> gen_
#if defined(ENABLE_METAX_API) && defined(INFINICORE_HPCC_VERSION_MAJOR) && (INFINICORE_HPCC_VERSION_MAJOR >= 3)
// MetaX/Mars `flash_attn_2_cuda` (e.g. 2.6.x+mars) appends this argument vs upstream Dao-AILab flash-attn.
,
std::optional<at::Tensor> &flash_attn_mars_ext_
#endif
);

std::vector<at::Tensor>
mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multiple_of(head_size_og, 8)
Expand Down Expand Up @@ -108,7 +119,15 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size
int window_size_right,
const float softcap,
bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
int num_splits);
int num_splits
#if defined(ENABLE_METAX_API) && defined(INFINICORE_HPCC_VERSION_MAJOR) && (INFINICORE_HPCC_VERSION_MAJOR >= 3)
// MetaX/Mars `flash_attn_2_cuda` (e.g. 2.6.x+mars) appends this argument vs upstream Dao-AILab flash-attn.
,
std::optional<at::Tensor> &flash_attn_mars_ext_
#endif
);

#if !defined(ENABLE_METAX_API)
} // namespace flash
#endif
#endif // ENABLE_FLASH_ATTN
16 changes: 10 additions & 6 deletions scripts/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
import subprocess
import platform
import sys
from set_env import set_env
from set_env import (
set_env,
set_env_by_config,
)

PROJECT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
os.chdir(PROJECT_DIR)
Expand All @@ -12,11 +15,12 @@ def run_cmd(cmd):


def install(xmake_config_flags=""):
    """Configure, build, and install InfiniCore plus infiniop-test via xmake.

    Backend-specific environment variables (e.g. MetaX include paths) are
    exported first from the same xmake config flags, then each xmake step runs
    with ``-y`` so the script can run unattended without interactive prompts.

    (This span previously interleaved pre-/post-diff command lines; resolved
    here to the post-change version.)
    """
    set_env_by_config(xmake_config_flags)
    run_cmd(f"xmake f -y {xmake_config_flags} -cv")
    run_cmd("xmake -y")
    run_cmd("xmake install -y")
    run_cmd("xmake build -y infiniop-test")
    run_cmd("xmake install -y infiniop-test")


if __name__ == "__main__":
Expand Down
67 changes: 67 additions & 0 deletions scripts/metax_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import os


def _first_existing_dir(paths: list[str]) -> str:
for p in paths:
if p and os.path.isdir(p):
return p
return ""


def _metax_toolkit_root(use_mc: bool) -> str:
"""Return toolkit root for MetaX builds (MACA when use-mc; otherwise HPCC)."""
if use_mc:
for key in ("MACA_PATH", "MACA_HOME", "MACA_ROOT"):
v = os.environ.get(key, "").strip()
if v:
return v
return _first_existing_dir(["/opt/maca"])
return _first_existing_dir(["/opt/hpcc"])


def _prepend_path_var(name: str, prefixes: list[str]) -> None:
"""Prepend colon-separated *prefixes* to env var *name* (POSIX)."""
if not prefixes:
return
chunk = ":".join(prefixes)
cur = os.environ.get(name, "")
os.environ[name] = f"{chunk}:{cur}" if cur else chunk


def set_env_for_metax_gpu(
    flags: str,
    *,
    parse_xmake_cli_flag_values,
    truthy_flag_value,
) -> None:
    """
    Prepend compiler include paths needed when building ATen-enabled C++ against torch headers.

    This chooses paths based on xmake backend flags (e.g. --metax-gpu) and toolkit selection
    (e.g. MetaX HPCC vs MACA when --use-mc=y).
    """
    parsed = parse_xmake_cli_flag_values(flags)
    # Only relevant when the ATen adaptor is being built.
    if not truthy_flag_value(parsed.get("aten", "n")):
        return
    # Only the MetaX backend needs the cu-bridge include shims.
    if not truthy_flag_value(parsed.get("metax-gpu", "n")):
        return
    root = _metax_toolkit_root(use_mc=truthy_flag_value(parsed.get("use-mc", "n")))
    if not root:
        return
    include_dirs = [
        os.path.join(root, "tools", "cu-bridge", "include"),
        os.path.join(root, "include", "hcr"),
        # cu-bridge cuComplex.h includes "hcComplex.h" from HPCC include/common
        os.path.join(root, "include", "common"),
        # cu-bridge cusparse wrapper includes "hcsparse.h" under include/hcsparse
        os.path.join(root, "include", "hcsparse"),
        # cu-bridge cublasLt wrapper includes "hcblasLt.h" under include/hcblas
        os.path.join(root, "include", "hcblas"),
        # cu-bridge cusolver wrapper includes "hcsolver_common.h" under include/hcsolver
        os.path.join(root, "include", "hcsolver"),
        os.path.join(root, "include"),
    ]
    for env_var in ("CPATH", "CPLUS_INCLUDE_PATH", "C_INCLUDE_PATH"):
        _prepend_path_var(env_var, include_dirs)
40 changes: 40 additions & 0 deletions scripts/set_env.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,46 @@
import os
import platform

from metax_env import set_env_for_metax_gpu


def _parse_xmake_cli_flag_values(flags: str):
"""Parse a string like '--metax-gpu=y --aten=y' into {key: value}."""
parts = flags.replace("=", " ").split()
d = {}
i = 0
n = len(parts)
while i < n:
p = parts[i]
if p.startswith("--") and len(p) > 2:
key = p[2:].lower()
i += 1
if i < n and not parts[i].startswith("--"):
d[key] = parts[i].lower()
i += 1
else:
d[key] = "y"
else:
i += 1
return d


def _truthy_flag_value(v: str) -> bool:
return v in ("y", "yes", "true", "1", "on")


def set_env_by_config(flags: str) -> None:
    """Set environment variables for InfiniCore builds with xmake config flags."""
    parsed = _parse_xmake_cli_flag_values(flags)
    # Only MetaX builds need extra environment setup today; other backends are a no-op.
    if not _truthy_flag_value(parsed.get("metax-gpu", "n")):
        return
    set_env_for_metax_gpu(
        flags,
        parse_xmake_cli_flag_values=_parse_xmake_cli_flag_values,
        truthy_flag_value=_truthy_flag_value,
    )


def set_env():
if os.environ.get("INFINI_ROOT") == None:
Expand Down
2 changes: 1 addition & 1 deletion src/infinicore/adaptor/aten_adaptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ at::Tensor to_aten_tensor(const infinicore::Tensor &t) {
options);
}

#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API)
c10::cuda::CUDAStream get_cuda_stream() {
return c10::cuda::getStreamFromExternal(
cudaStream_t(infinicore::context::getStream()), infinicore::context::getDevice().getIndex());
Expand Down
43 changes: 37 additions & 6 deletions src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,18 @@

#include <stdexcept>

#ifdef ENABLE_FLASH_ATTN
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API)
#include <c10/cuda/CUDAGuard.h>
#endif
#endif

#if defined(ENABLE_METAX_API)
#define INFINICORE_FLASH_OP(name) ::name
#else
#define INFINICORE_FLASH_OP(name) flash::name
#endif

namespace infinicore::op::mha_kvcache_impl::flashattn {

struct PlannedMeta {
Expand Down Expand Up @@ -33,17 +45,24 @@ void *plan(Tensor out,

void run(void *planned_meta) {
#ifdef ENABLE_FLASH_ATTN
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API)
c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream());
#endif
auto *p = reinterpret_cast<PlannedMeta *>(planned_meta);

auto out_tensor = infinicore::adaptor::to_aten_tensor(p->out);
// Paged KV caches must be contiguous for flash-attn; avoid extra copies for q/metadata when already dense.
const bool out_need_copy_back = !p->out->is_contiguous();
Tensor out_work = out_need_copy_back ? p->out->contiguous() : Tensor(p->out);
auto out_tensor = infinicore::adaptor::to_aten_tensor(out_work);
auto q = infinicore::adaptor::to_aten_tensor(p->q);
#if defined(ENABLE_NVIDIA_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API)
auto k_cache = infinicore::adaptor::to_aten_tensor(p->k_cache);
auto v_cache = infinicore::adaptor::to_aten_tensor(p->v_cache);
#elif defined(ENABLE_QY_API)
auto k_cache = infinicore::adaptor::to_aten_tensor(p->k_cache).contiguous();
auto v_cache = infinicore::adaptor::to_aten_tensor(p->v_cache).contiguous();
Tensor k_cache_work = p->k_cache->contiguous();
Tensor v_cache_work = p->v_cache->contiguous();
auto k_cache = infinicore::adaptor::to_aten_tensor(k_cache_work);
auto v_cache = infinicore::adaptor::to_aten_tensor(v_cache_work);
#endif
auto seqlens_k = std::optional<const at::Tensor>(infinicore::adaptor::to_aten_tensor(p->seqlens_k));
auto block_table = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(p->block_table));
Expand All @@ -65,7 +84,11 @@ void run(void *planned_meta) {
auto out = use_dynamic_out ? std::optional<at::Tensor>(std::nullopt)
: std::optional<at::Tensor>(out_tensor);

auto result = flash::mha_fwd_kvcache(
#if defined(ENABLE_METAX_API) && defined(INFINICORE_HPCC_VERSION_MAJOR) && (INFINICORE_HPCC_VERSION_MAJOR >= 3)
std::optional<at::Tensor> flash_attn_mars_ext = std::nullopt;
#endif

auto result = INFINICORE_FLASH_OP(mha_fwd_kvcache)(
q,
k_cache,
v_cache,
Expand All @@ -85,11 +108,19 @@ void run(void *planned_meta) {
-1,
0.0f,
false,
0);
0
#if defined(ENABLE_METAX_API) && defined(INFINICORE_HPCC_VERSION_MAJOR) && (INFINICORE_HPCC_VERSION_MAJOR >= 3)
,
flash_attn_mars_ext
#endif
);

if (use_dynamic_out) {
out_tensor.copy_(result[0]);
}
if (out_need_copy_back) {
p->out->copy_from(out_work);
}
#else
throw std::runtime_error("FlashAttention is not enabled in this build");
#endif
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@

#include <stdexcept>

#ifdef ENABLE_FLASH_ATTN
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API)
#include <c10/cuda/CUDAGuard.h>
#endif
#endif

namespace infinicore::op::mha_varlen_impl::flashattn {

struct PlannedMeta {
Expand Down Expand Up @@ -39,6 +45,20 @@ void *plan(Tensor out,
scale};
}

namespace {

#ifdef ENABLE_FLASH_ATTN
// MetaX/hpcc pip `flash_attn_2_cuda` exports `mha_varlen_fwd` at global scope (no namespace),
// while NVIDIA `flash-attn-nvidia.so` uses `flash::mha_varlen_fwd`.
// NOTE(review): preprocessor macros are not scoped by namespaces — this anonymous
// namespace does not confine INFINICORE_FLASH_OP; consider #undef after last use
// to avoid redefinition clashes with sibling TUs defining the same macro.
#if defined(ENABLE_METAX_API)
#define INFINICORE_FLASH_OP(name) ::name
#else
#define INFINICORE_FLASH_OP(name) flash::name
#endif

#endif // ENABLE_FLASH_ATTN
} // namespace

void run(void *planned_meta) {
#ifdef ENABLE_FLASH_ATTN
c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream());
Expand All @@ -47,7 +67,12 @@ void run(void *planned_meta) {
auto q = infinicore::adaptor::to_aten_tensor(p->q);
auto k = infinicore::adaptor::to_aten_tensor(p->k);
auto v = infinicore::adaptor::to_aten_tensor(p->v);
auto out = std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(p->out));

const bool out_need_copy_back = !p->out->is_contiguous();
Tensor out_work_ic = out_need_copy_back ? p->out->contiguous() : Tensor(p->out);
auto out_work = infinicore::adaptor::to_aten_tensor(out_work_ic);
auto out = std::optional<at::Tensor>(out_work);

auto cu_seqlens_q = infinicore::adaptor::to_aten_tensor(p->cum_seqlens_q);
auto cu_seqlens_kv = infinicore::adaptor::to_aten_tensor(p->cum_seqlens_k);
std::optional<at::Tensor> seqused_k = std::nullopt;
Expand All @@ -58,7 +83,12 @@ void run(void *planned_meta) {
auto alibi_slopes = p->alibi_slopes ? std::optional<at::Tensor>(infinicore::adaptor::to_aten_tensor(*p->alibi_slopes)) : std::nullopt;
auto scale = p->scale;

flash::mha_varlen_fwd(
#if defined(ENABLE_METAX_API) && defined(INFINICORE_HPCC_VERSION_MAJOR) && (INFINICORE_HPCC_VERSION_MAJOR >= 3)
std::optional<at::Tensor> flash_attn_mars_ext = std::nullopt;
#endif

INFINICORE_FLASH_OP(mha_varlen_fwd)
(
q,
k,
v,
Expand All @@ -79,7 +109,17 @@ void run(void *planned_meta) {
-1,
0.0,
false,
std::nullopt);
std::nullopt
#if defined(ENABLE_METAX_API) && defined(INFINICORE_HPCC_VERSION_MAJOR) && (INFINICORE_HPCC_VERSION_MAJOR >= 3)
,
flash_attn_mars_ext
#endif
);

if (out_need_copy_back) {
p->out->copy_from(out_work_ic);
}

#else
throw std::runtime_error("FlashAttention is not enabled in this build");
#endif
Expand Down
Loading
Loading