diff --git a/include/infinicore/adaptor/aten_adaptor.hpp b/include/infinicore/adaptor/aten_adaptor.hpp
index e7a852085..00d5cbec2 100644
--- a/include/infinicore/adaptor/aten_adaptor.hpp
+++ b/include/infinicore/adaptor/aten_adaptor.hpp
@@ -5,7 +5,7 @@
 
 #include
 
-#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API)
 #include
 #include
 #include
@@ -30,12 +30,12 @@ inline at::ScalarType to_at_dtype(DataType dtype) {
 }
 
 inline at::Device to_at_device(const Device &device) {
-    if (device.getType() == Device::Type::NVIDIA) {
+    // PyTorch ATen only exposes standard device types (e.g. kCPU/kCUDA).
+    // Treat MetaX/QY devices as CUDA devices for ATen tensor interoperability.
+    if (device.getType() == Device::Type::NVIDIA || device.getType() == Device::Type::METAX || device.getType() == Device::Type::QY) {
         return at::Device(at::kCUDA, device.getIndex());
     } else if (device.getType() == Device::Type::CPU) {
         return at::Device(at::kCPU);
-    } else if (device.getType() == Device::Type::QY) {
-        return at::Device(at::kCUDA, device.getIndex());
     } else {
         throw std::runtime_error("Unsupported device type for ATen");
     }
@@ -43,7 +43,7 @@ inline at::Device to_at_device(const Device &device) {
 
 at::Tensor to_aten_tensor(const infinicore::Tensor &t);
 
-#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API)
 c10::cuda::CUDAStream get_cuda_stream();
 #endif
 } // namespace infinicore::adaptor
diff --git a/include/infinicore/adaptor/flash_attention_adaptor.hpp b/include/infinicore/adaptor/flash_attention_adaptor.hpp
index 8a9e152fd..9ffcc42d6 100644
--- a/include/infinicore/adaptor/flash_attention_adaptor.hpp
+++ b/include/infinicore/adaptor/flash_attention_adaptor.hpp
@@ -2,7 +2,12 @@
 #pragma once
 #include "aten_adaptor.hpp"
+// NVIDIA flash-attn-nvidia.so uses namespace flash. The pip/MetaX flash_attn_2_cuda extension
+// exports the same entry points at global scope (no namespace), matching FLASH_NAMESPACE builds
+// where the namespace is empty.
+#if !defined(ENABLE_METAX_API)
 namespace flash {
+#endif
 
 std::vector<at::Tensor>
 mha_fwd(at::Tensor &q,        // batch_size x seqlen_q x num_heads x round_multiple(head_size, 8)
         const at::Tensor &k,  // batch_size x seqlen_k x num_heads_k x round_multiple(head_size, 8)
@@ -39,7 +44,13 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_hea
                int window_size_right,
                const float softcap,
                const bool return_softmax,
-               std::optional<at::Generator> gen_);
+               std::optional<at::Generator> gen_
+#if defined(ENABLE_METAX_API) && defined(INFINICORE_HPCC_VERSION_MAJOR) && (INFINICORE_HPCC_VERSION_MAJOR >= 3)
+               // MetaX/Mars `flash_attn_2_cuda` (e.g. 2.6.x+mars) appends this argument vs upstream Dao-AILab flash-attn.
+               ,
+               std::optional &flash_attn_mars_ext_
+#endif
+               );
 
 std::vector<at::Tensor>
 mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multiple_of(head_size_og, 8)
@@ -108,7 +119,15 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size
                 int window_size_right,
                 const float softcap,
                 bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
-                int num_splits);
+                int num_splits
+#if defined(ENABLE_METAX_API) && defined(INFINICORE_HPCC_VERSION_MAJOR) && (INFINICORE_HPCC_VERSION_MAJOR >= 3)
+                // MetaX/Mars `flash_attn_2_cuda` (e.g. 2.6.x+mars) appends this argument vs upstream Dao-AILab flash-attn.
+                ,
+                std::optional &flash_attn_mars_ext_
+#endif
+                );
+#if !defined(ENABLE_METAX_API)
 } // namespace flash
+#endif
 
 #endif // ENABLE_FLASH_ATTN
diff --git a/scripts/install.py b/scripts/install.py
index 2e420ee9f..98a448254 100644
--- a/scripts/install.py
+++ b/scripts/install.py
@@ -2,7 +2,10 @@
 import subprocess
 import platform
 import sys
-from set_env import set_env
+from set_env import (
+    set_env,
+    set_env_by_config,
+)
 
 PROJECT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 os.chdir(PROJECT_DIR)
@@ -12,11 +15,12 @@ def run_cmd(cmd):
 
 
 def install(xmake_config_flags=""):
-    run_cmd(f"xmake f {xmake_config_flags} -cv")
-    run_cmd("xmake")
-    run_cmd("xmake install")
-    run_cmd("xmake build infiniop-test")
-    run_cmd("xmake install infiniop-test")
+    set_env_by_config(xmake_config_flags)
+    run_cmd(f"xmake f -y {xmake_config_flags} -cv")
+    run_cmd("xmake -y")
+    run_cmd("xmake install -y")
+    run_cmd("xmake build -y infiniop-test")
+    run_cmd("xmake install -y infiniop-test")
 
 
 if __name__ == "__main__":
diff --git a/scripts/metax_env.py b/scripts/metax_env.py
new file mode 100644
index 000000000..df35c57a8
--- /dev/null
+++ b/scripts/metax_env.py
@@ -0,0 +1,67 @@
+import os
+
+
+def _first_existing_dir(paths: list[str]) -> str:
+    for p in paths:
+        if p and os.path.isdir(p):
+            return p
+    return ""
+
+
+def _metax_toolkit_root(use_mc: bool) -> str:
+    """Return toolkit root for MetaX builds (MACA when use-mc; otherwise HPCC)."""
+    if use_mc:
+        for key in ("MACA_PATH", "MACA_HOME", "MACA_ROOT"):
+            v = os.environ.get(key, "").strip()
+            if v:
+                return v
+        return _first_existing_dir(["/opt/maca"])
+    return _first_existing_dir(["/opt/hpcc"])
+
+
+def _prepend_path_var(name: str, prefixes: list[str]) -> None:
+    """Prepend colon-separated *prefixes* to env var *name* (POSIX)."""
+    if not prefixes:
+        return
+    chunk = ":".join(prefixes)
+    cur = os.environ.get(name, "")
+    os.environ[name] = f"{chunk}:{cur}" if cur else chunk
+
+
+def set_env_for_metax_gpu(
+    flags: str,
+    *,
+    parse_xmake_cli_flag_values,
+    truthy_flag_value,
+) -> None:
+    """
+    Prepend compiler include paths needed when building ATen-enabled C++ against torch headers.
+
+    This chooses paths based on xmake backend flags (e.g. --metax-gpu) and toolkit selection
+    (e.g. MetaX HPCC vs MACA when --use-mc=y).
+    """
+    d = parse_xmake_cli_flag_values(flags)
+    if not truthy_flag_value(d.get("aten", "n")):
+        return
+
+    if truthy_flag_value(d.get("metax-gpu", "n")):
+        use_mc = truthy_flag_value(d.get("use-mc", "n"))
+        root = _metax_toolkit_root(use_mc=use_mc)
+        if not root:
+            return
+        dirs = [
+            os.path.join(root, "tools", "cu-bridge", "include"),
+            os.path.join(root, "include", "hcr"),
+            # cu-bridge cuComplex.h includes "hcComplex.h" from HPCC include/common
+            os.path.join(root, "include", "common"),
+            # cu-bridge cusparse wrapper includes "hcsparse.h" under include/hcsparse
+            os.path.join(root, "include", "hcsparse"),
+            # cu-bridge cublasLt wrapper includes "hcblasLt.h" under include/hcblas
+            os.path.join(root, "include", "hcblas"),
+            # cu-bridge cusolver wrapper includes "hcsolver_common.h" under include/hcsolver
+            os.path.join(root, "include", "hcsolver"),
+            os.path.join(root, "include"),
+        ]
+        for var in ("CPATH", "CPLUS_INCLUDE_PATH", "C_INCLUDE_PATH"):
+            _prepend_path_var(var, dirs)
+        return
diff --git a/scripts/set_env.py b/scripts/set_env.py
index d1d4c2184..f489c28ad 100644
--- a/scripts/set_env.py
+++ b/scripts/set_env.py
@@ -1,6 +1,46 @@
 import os
 import platform
+from metax_env import set_env_for_metax_gpu
+
+
+def _parse_xmake_cli_flag_values(flags: str):
+    """Parse a string like '--metax-gpu=y --aten=y' into {key: value}."""
+    parts = flags.replace("=", " ").split()
+    d = {}
+    i = 0
+    n = len(parts)
+    while i < n:
+        p = parts[i]
+        if p.startswith("--") and len(p) > 2:
+            key = p[2:].lower()
+            i += 1
+            if i < n and not parts[i].startswith("--"):
+                d[key] = parts[i].lower()
+                i += 1
+            else:
+                d[key] = "y"
+        else:
+            i += 1
+    return d
+
+
+def _truthy_flag_value(v: str) -> bool:
+    return v in ("y", "yes", "true", "1", "on")
+
+
+def set_env_by_config(flags: str) -> None:
+    """Set environment variables for InfiniCore builds with xmake config flags."""
+    d = _parse_xmake_cli_flag_values(flags)
+    if _truthy_flag_value(d.get("metax-gpu", "n")):
+        set_env_for_metax_gpu(
+            flags,
+            parse_xmake_cli_flag_values=_parse_xmake_cli_flag_values,
+            truthy_flag_value=_truthy_flag_value,
+        )
+    else:
+        pass
+
 
 
 def set_env():
     if os.environ.get("INFINI_ROOT") == None:
diff --git a/src/infinicore/adaptor/aten_adaptor.cc b/src/infinicore/adaptor/aten_adaptor.cc
index 2ffe396ef..04db643f9 100644
--- a/src/infinicore/adaptor/aten_adaptor.cc
+++ b/src/infinicore/adaptor/aten_adaptor.cc
@@ -32,7 +32,7 @@ at::Tensor to_aten_tensor(const infinicore::Tensor &t) {
         options);
 }
 
-#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API)
 c10::cuda::CUDAStream get_cuda_stream() {
     return c10::cuda::getStreamFromExternal(
         cudaStream_t(infinicore::context::getStream()), infinicore::context::getDevice().getIndex());
diff --git a/src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc b/src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc
index 677b85d88..0167c17df 100644
--- a/src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc
+++ b/src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc
@@ -4,6 +4,18 @@
 
 #include
 
+#ifdef ENABLE_FLASH_ATTN
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API)
+#include
+#endif
+#endif
+
+#if defined(ENABLE_METAX_API)
+#define INFINICORE_FLASH_OP(name) ::name
+#else
+#define INFINICORE_FLASH_OP(name) flash::name
+#endif
+
 namespace infinicore::op::mha_kvcache_impl::flashattn {
 
 struct PlannedMeta {
@@ -33,17 +45,24 @@ void *plan(Tensor out,
 
 void run(void *planned_meta) {
 #ifdef ENABLE_FLASH_ATTN
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API)
     c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream());
+#endif
     auto *p = reinterpret_cast<PlannedMeta *>(planned_meta);
-    auto out_tensor = infinicore::adaptor::to_aten_tensor(p->out);
+    // Paged KV caches must be contiguous for flash-attn; avoid extra copies for q/metadata when already dense.
+    const bool out_need_copy_back = !p->out->is_contiguous();
+    Tensor out_work = out_need_copy_back ? p->out->contiguous() : Tensor(p->out);
+    auto out_tensor = infinicore::adaptor::to_aten_tensor(out_work);
     auto q = infinicore::adaptor::to_aten_tensor(p->q);
-#if defined(ENABLE_NVIDIA_API)
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API)
     auto k_cache = infinicore::adaptor::to_aten_tensor(p->k_cache);
    auto v_cache = infinicore::adaptor::to_aten_tensor(p->v_cache);
 #elif defined(ENABLE_QY_API)
-    auto k_cache = infinicore::adaptor::to_aten_tensor(p->k_cache).contiguous();
-    auto v_cache = infinicore::adaptor::to_aten_tensor(p->v_cache).contiguous();
+    Tensor k_cache_work = p->k_cache->contiguous();
+    Tensor v_cache_work = p->v_cache->contiguous();
+    auto k_cache = infinicore::adaptor::to_aten_tensor(k_cache_work);
+    auto v_cache = infinicore::adaptor::to_aten_tensor(v_cache_work);
 #endif
     auto seqlens_k = std::optional(infinicore::adaptor::to_aten_tensor(p->seqlens_k));
     auto block_table = std::optional(infinicore::adaptor::to_aten_tensor(p->block_table));
@@ -65,7 +84,11 @@ void run(void *planned_meta) {
 
     auto out = use_dynamic_out ? std::optional(std::nullopt) : std::optional(out_tensor);
 
-    auto result = flash::mha_fwd_kvcache(
+#if defined(ENABLE_METAX_API) && defined(INFINICORE_HPCC_VERSION_MAJOR) && (INFINICORE_HPCC_VERSION_MAJOR >= 3)
+    std::optional flash_attn_mars_ext = std::nullopt;
+#endif
+
+    auto result = INFINICORE_FLASH_OP(mha_fwd_kvcache)(
         q,
         k_cache,
         v_cache,
@@ -85,11 +108,19 @@ void run(void *planned_meta) {
         -1,
         0.0f,
         false,
-        0);
+        0
+#if defined(ENABLE_METAX_API) && defined(INFINICORE_HPCC_VERSION_MAJOR) && (INFINICORE_HPCC_VERSION_MAJOR >= 3)
+        ,
+        flash_attn_mars_ext
+#endif
+        );
 
     if (use_dynamic_out) {
         out_tensor.copy_(result[0]);
     }
+    if (out_need_copy_back) {
+        p->out->copy_from(out_work);
+    }
 #else
     throw std::runtime_error("FlashAttention is not enabled in this build");
 #endif
diff --git a/src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc b/src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc
index aff085898..f80107e7e 100644
--- a/src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc
+++ b/src/infinicore/ops/multi_head_attention_varlen/mha_varlen_flashattn.cc
@@ -4,6 +4,12 @@
 
 #include
 
+#ifdef ENABLE_FLASH_ATTN
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API)
+#include
+#endif
+#endif
+
 namespace infinicore::op::mha_varlen_impl::flashattn {
 
 struct PlannedMeta {
@@ -39,6 +45,20 @@ void *plan(Tensor out,
         scale};
 }
 
+namespace {
+
+#ifdef ENABLE_FLASH_ATTN
+// MetaX/hpcc pip `flash_attn_2_cuda` exports `mha_varlen_fwd` at global scope (no namespace),
+// while NVIDIA `flash-attn-nvidia.so` uses `flash::mha_varlen_fwd`.
+#if defined(ENABLE_METAX_API)
+#define INFINICORE_FLASH_OP(name) ::name
+#else
+#define INFINICORE_FLASH_OP(name) flash::name
+#endif
+
+#endif // ENABLE_FLASH_ATTN
+} // namespace
+
 void run(void *planned_meta) {
 #ifdef ENABLE_FLASH_ATTN
     c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream());
@@ -47,7 +67,12 @@ void run(void *planned_meta) {
     auto q = infinicore::adaptor::to_aten_tensor(p->q);
     auto k = infinicore::adaptor::to_aten_tensor(p->k);
     auto v = infinicore::adaptor::to_aten_tensor(p->v);
-    auto out = std::optional(infinicore::adaptor::to_aten_tensor(p->out));
+
+    const bool out_need_copy_back = !p->out->is_contiguous();
+    Tensor out_work_ic = out_need_copy_back ? p->out->contiguous() : Tensor(p->out);
+    auto out_work = infinicore::adaptor::to_aten_tensor(out_work_ic);
+    auto out = std::optional(out_work);
+
     auto cu_seqlens_q = infinicore::adaptor::to_aten_tensor(p->cum_seqlens_q);
     auto cu_seqlens_kv = infinicore::adaptor::to_aten_tensor(p->cum_seqlens_k);
     std::optional seqused_k = std::nullopt;
@@ -58,7 +83,12 @@ void run(void *planned_meta) {
     auto alibi_slopes = p->alibi_slopes ? std::optional(infinicore::adaptor::to_aten_tensor(*p->alibi_slopes)) : std::nullopt;
     auto scale = p->scale;
 
-    flash::mha_varlen_fwd(
+#if defined(ENABLE_METAX_API) && defined(INFINICORE_HPCC_VERSION_MAJOR) && (INFINICORE_HPCC_VERSION_MAJOR >= 3)
+    std::optional flash_attn_mars_ext = std::nullopt;
+#endif
+
+    INFINICORE_FLASH_OP(mha_varlen_fwd)
+    (
         q,
         k,
         v,
@@ -79,7 +109,17 @@ void run(void *planned_meta) {
         -1,
         0.0,
         false,
-        std::nullopt);
+        std::nullopt
+#if defined(ENABLE_METAX_API) && defined(INFINICORE_HPCC_VERSION_MAJOR) && (INFINICORE_HPCC_VERSION_MAJOR >= 3)
+        ,
+        flash_attn_mars_ext
+#endif
+        );
+
+    if (out_need_copy_back) {
+        p->out->copy_from(out_work_ic);
+    }
+
 #else
     throw std::runtime_error("FlashAttention is not enabled in this build");
 #endif
diff --git a/xmake.lua b/xmake.lua
index c69e0170d..ecea63518 100644
--- a/xmake.lua
+++ b/xmake.lua
@@ -169,6 +169,11 @@ if has_config("metax-gpu") then
     add_defines("ENABLE_METAX_API")
     if has_config("use-mc") then
         add_defines("ENABLE_METAX_MC_API")
+        -- MACA torch build expects USE_MACA for ATen headers (e.g. C10_WARP_SIZE).
+        add_defines("USE_MACA")
+    else
+        -- HPCC torch build expects this for ATen headers on hpcc.
+        add_defines("USE_HPCC")
     end
     includes("xmake/metax.lua")
 end
@@ -235,14 +240,14 @@ option_end()
 
 -- Flash-Attn
 option("flash-attn")
-    set_default("")
+    set_default(nil)
     set_showmenu(true)
     set_description("Path to flash-attention repo. If not set, flash-attention will not used.")
 option_end()
 
 if has_config("aten") then
     add_defines("ENABLE_ATEN")
-    if get_config("flash-attn") ~= false then
+    if get_config("flash-attn") and get_config("flash-attn") ~= "" then
         add_defines("ENABLE_FLASH_ATTN")
     end
 end
@@ -258,6 +263,7 @@ if has_config("graph") then
     add_defines("USE_INFINIRT_GRAPH")
 end
 
+
 -- InfiniCCL
 option("ccl")
     set_default(false)
@@ -463,30 +469,48 @@ target("infinicore_cpp_api")
     add_linkdirs(INFINI_ROOT.."/lib")
     add_links("infiniop", "infinirt", "infiniccl")
 
-    if get_config("flash-attn") ~= "" and get_config("flash-attn") ~= nil then
+    if get_config("flash-attn") and get_config("flash-attn") ~= "" then
         add_installfiles("(builddir)/$(plat)/$(arch)/$(mode)/flash-attn*.so", {prefixdir = "lib"})
         if has_config("nv-gpu") then
             add_deps("flash-attn-nvidia")
         end
+        if has_config("metax-gpu") then
+            add_deps("flash-attn-metax")
+        end
         if has_config("qy-gpu") then
             add_deps("flash-attn-qy")
         end
     end
 
-    if get_config("flash-attn") and get_config("flash-attn") ~= "" and has_config("qy-gpu") then
-        local flash_so_qy = _qy_flash_attn_cuda_so_path()
-        local flash_dir_qy = path.directory(flash_so_qy)
-        local flash_name_qy = path.filename(flash_so_qy)
-        before_link(function (target)
-            target:add(
-                "shflags",
-                "-Wl,--no-as-needed -L" .. flash_dir_qy .. " -l:" .. flash_name_qy .. " -Wl,-rpath," .. flash_dir_qy,
-                {force = true}
-            )
-        end)
-    end
+    -- Flash pip `.so` link flags: `before_link` runs in an xmake sandbox that cannot see helpers
+    -- from other included scripts; MetaX and QY each register their own hook in `xmake/metax.lua`
+    -- and `xmake/qy.lua`.
 
     before_build(function (target)
+        -- MetaX + flash-attn: `flash_attn_2_cuda` may use a different `mha_fwd_kvcache` ABI
+        -- depending on the underlying stack version. When building with MACA (`--use-mc=y`),
+        -- the version file is typically `/opt/maca/Version.txt` (HPCC uses `/opt/hpcc/Version.txt`).
+        if has_config("metax-gpu") and get_config("flash-attn") and get_config("flash-attn") ~= "" then
+            local version_txt = "/opt/hpcc/Version.txt"
+            if not os.isfile(version_txt) and has_config("use-mc") then
+                version_txt = "/opt/maca/Version.txt"
+            end
+            if os.isfile(version_txt) then
+                local content = os.iorunv("cat", {version_txt}) or ""
+                content = content:trim()
+                local major_str = content:match("Version:(%d+)") or content:match("^(%d+)")
+                if major_str and major_str ~= "" then
+                    local major = tonumber(major_str)
+                    if major then
+                        local define = "INFINICORE_HPCC_VERSION_MAJOR=" .. tostring(major)
+                        target:add("defines", define)
+                        target:add("cxflags", "-D" .. define)
+                        target:add("cxxflags", "-D" .. define)
+                    end
+                end
+            end
+        end
+
         if has_config("aten") then
             local outdata = os.iorunv("python", {"-c", "import torch, os; print(os.path.dirname(torch.__file__))"}):trim()
             local TORCH_DIR = outdata
@@ -537,7 +561,6 @@ target("infinicore_cpp_api")
 target_end()
 
 target("_infinicore")
-    add_packages("boost")
     if is_mode("debug") then
         add_defines("BOOST_STACKTRACE_USE_BACKTRACE")
         add_links("backtrace")
diff --git a/xmake/metax.lua b/xmake/metax.lua
index e7071d4bb..85407ed1b 100644
--- a/xmake/metax.lua
+++ b/xmake/metax.lua
@@ -1,5 +1,63 @@
 local MACA_ROOT = os.getenv("MACA_PATH") or os.getenv("MACA_HOME") or os.getenv("MACA_ROOT")
+local FLASH_ATTN_ROOT = get_config("flash-attn")
+
+-- MetaX flash-attn (pip `flash_attn_2_cuda`) may append an extra trailing argument
+-- (`flash_attn_mars_ext_`) depending on the underlying HPCC/MetaX stack version.
+do
+    -- Intentionally empty: HPCC version parsing is deferred to `before_build`
+    -- on `infinicore_cpp_api` where `os.iorunv` is available in this xmake sandbox.
+end
+
+-- Resolve MetaX flash-attn .so path (used only from this file: `before_link` sandbox cannot see globals from `xmake.lua`).
+local FLASH_ATTN_METAX_CUDA_SO_CONTAINER_DEFAULT =
+    "/opt/conda/lib/python3.10/site-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so"
+
+local function metax_flash_attn_cuda_so_path()
+    -- Highest priority: override the exact `.so` file to link.
+    local env_path = os.getenv("FLASH_ATTN_2_CUDA_SO")
+    if env_path and env_path ~= "" then
+        env_path = env_path:trim()
+        if os.isfile(env_path) then
+            return env_path
+        end
+        print(string.format("warning: metax+flash-attn: FLASH_ATTN_2_CUDA_SO is not a file: %s, fallback to container/default path", env_path))
+    end
+
+    -- Second priority: allow overriding the "expected" container path via env.
+    local container_path = os.getenv("FLASH_ATTN_METAX_CUDA_SO_CONTAINER")
+    if not container_path or container_path == "" then
+        container_path = FLASH_ATTN_METAX_CUDA_SO_CONTAINER_DEFAULT
+    end
+
+    if not os.isfile(container_path) then
+        print(
+            string.format(
+                "warning: metax+flash-attn: expected %s; install flash-attn in conda env, or export FLASH_ATTN_2_CUDA_SO.",
+                container_path
+            )
+        )
+    end
+    return container_path
+end
+
+-- MetaX flash-attn link flags for pip `flash_attn_2_cuda`.
+-- Version/ABI macros are set in `xmake.lua` for `infinicore_cpp_api` so they apply to all sources.
+target("infinicore_cpp_api")
+    if get_config("flash-attn") and get_config("flash-attn") ~= "" then
+        before_link(function (target)
+            local flash_so_metax = metax_flash_attn_cuda_so_path()
+            local flash_dir_metax = path.directory(flash_so_metax)
+            local flash_name_metax = path.filename(flash_so_metax)
+            target:add(
+                "shflags",
+                "-Wl,--no-as-needed -L" .. flash_dir_metax .. " -l:" .. flash_name_metax .. " -Wl,-rpath," .. flash_dir_metax,
+                {force = true}
+            )
+        end)
+    end
+target_end()
+
 add_includedirs(MACA_ROOT .. "/include")
 add_linkdirs(MACA_ROOT .. "/lib")
 if has_config("use-mc") then
@@ -57,8 +115,8 @@ target("infiniop-metax")
         add_includedirs(MACA_ROOT .. "/include/mcr")
         add_files("../build/ninetoothed/*.c", "../build/ninetoothed/*.cpp", {
             cxflags = {
-                "-include stdlib.h", 
-                "-Wno-return-type", 
+                "-include stdlib.h",
+                "-Wno-return-type",
                 "-Wno-implicit-function-declaration",
                 "-Wno-builtin-declaration-mismatch"
             }
         })
     end
 target_end()
 
+target("flash-attn-metax")
+    set_kind("phony")
+    set_default(false)
+
+    if FLASH_ATTN_ROOT and FLASH_ATTN_ROOT ~= "" then
+        before_build(function (target)
+            local TORCH_DIR = os.iorunv("python", {"-c", "import torch, os; print(os.path.dirname(torch.__file__))"}):trim()
+            local PYTHON_INCLUDE = os.iorunv("python", {"-c", "import sysconfig; print(sysconfig.get_paths()['include'])"}):trim()
+            local PYTHON_LIB_DIR = os.iorunv("python", {"-c", "import sysconfig; print(sysconfig.get_config_var('LIBDIR'))"}):trim()
+
+            -- Validate build/runtime env in container and keep these paths available for downstream linking.
+            target:add("includedirs", TORCH_DIR .. "/include", TORCH_DIR .. "/include/torch/csrc/api/include", PYTHON_INCLUDE, {public = false})
+            target:add("linkdirs", TORCH_DIR .. "/lib", PYTHON_LIB_DIR, {public = false})
+        end)
+    else
+        before_build(function (target)
+            print("Flash Attention not available, skipping flash-attn-metax integration")
+        end)
+    end
+target_end()
+
 target("infinirt-metax")
     set_kind("static")
     set_languages("cxx17")
diff --git a/xmake/qy.lua b/xmake/qy.lua
index 31b65c33c..5b10e111b 100644
--- a/xmake/qy.lua
+++ b/xmake/qy.lua
@@ -219,3 +219,18 @@ target("flash-attn-qy")
         end)
     end
 target_end()
+
+if FLASH_ATTN_ROOT and FLASH_ATTN_ROOT ~= "" then
+    target("infinicore_cpp_api")
+        before_link(function (target)
+            local flash_so_qy = _qy_flash_attn_cuda_so_path()
+            local flash_dir_qy = path.directory(flash_so_qy)
+            local flash_name_qy = path.filename(flash_so_qy)
+            target:add(
+                "shflags",
+                "-Wl,--no-as-needed -L" .. flash_dir_qy .. " -l:" .. flash_name_qy .. " -Wl,-rpath," .. flash_dir_qy,
+                {force = true}
+            )
+        end)
+    target_end()
+end