diff --git a/include/infiniop/ops/awq_marlin_gemm.h b/include/infiniop/ops/awq_marlin_gemm.h
new file mode 100644
index 000000000..2d3fefcb5
--- /dev/null
+++ b/include/infiniop/ops/awq_marlin_gemm.h
@@ -0,0 +1,46 @@
+#ifndef __INFINIOP_AWQ_MARLIN_GEMM_API_H__
+#define __INFINIOP_AWQ_MARLIN_GEMM_API_H__
+
+#include "../operator_descriptor.h"
+#include <stdint.h>
+
+typedef struct InfiniopDescriptor *infiniopAwqMarlinGemmDescriptor_t;
+
+__INFINI_C __export infiniStatus_t infiniopCreateAwqMarlinGemmDescriptor(infiniopHandle_t handle,
+                                                                         infiniopAwqMarlinGemmDescriptor_t *desc_ptr,
+                                                                         infiniopTensorDescriptor_t out_desc,
+                                                                         infiniopTensorDescriptor_t a_desc,
+                                                                         infiniopTensorDescriptor_t b_desc,
+                                                                         infiniopTensorDescriptor_t b_bias_desc,
+                                                                         infiniopTensorDescriptor_t b_scales_desc,
+                                                                         infiniopTensorDescriptor_t a_scales_desc,
+                                                                         infiniopTensorDescriptor_t global_scales_desc,
+                                                                         infiniopTensorDescriptor_t b_zeros_desc,
+                                                                         infiniopTensorDescriptor_t g_idx_desc,
+                                                                         infiniopTensorDescriptor_t perm_desc);
+
+__INFINI_C __export infiniStatus_t infiniopGetAwqMarlinGemmWorkspaceSize(infiniopAwqMarlinGemmDescriptor_t desc, size_t *size);
+
+__INFINI_C __export infiniStatus_t infiniopAwqMarlinGemm(infiniopAwqMarlinGemmDescriptor_t desc,
+                                                         void *workspace,
+                                                         size_t workspace_size,
+                                                         void *c,
+                                                         const void *a,
+                                                         const void *b,
+                                                         void *b_bias,
+                                                         void *b_scales,
+                                                         void *a_scales,
+                                                         void *global_scales,
+                                                         void *b_zeros,
+                                                         void *g_idx,
+                                                         void *perm,
+                                                         int64_t b_q_type_id,
+                                                         bool is_k_full,
+                                                         bool use_atomic_add,
+                                                         bool use_fp32_reduce,
+                                                         bool is_zp_float,
+                                                         void *stream);
+
+__INFINI_C __export infiniStatus_t infiniopDestroyAwqMarlinGemmDescriptor(infiniopAwqMarlinGemmDescriptor_t desc);
+
+#endif
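Note on usage: the header implies InfiniOP's usual create / query-workspace / run / destroy lifecycle. A minimal host-side sketch, assuming a valid `handle`, tensor descriptors, device buffers, and a `CHECK` status macro (none of which are part of this patch); passing `vllm::kU4.id()` as the weight-type id for AWQ's unsigned 4-bit weights is also an assumption:

```cpp
infiniopAwqMarlinGemmDescriptor_t desc;
CHECK(infiniopCreateAwqMarlinGemmDescriptor(handle, &desc, out_desc, a_desc, b_desc,
                                            b_bias_desc, b_scales_desc, a_scales_desc,
                                            global_scales_desc, b_zeros_desc,
                                            g_idx_desc, perm_desc));

// Workspace is owned by the caller: query the size, then hand the buffer in.
size_t workspace_size = 0;
CHECK(infiniopGetAwqMarlinGemmWorkspaceSize(desc, &workspace_size));
void *workspace = nullptr;
cudaMalloc(&workspace, workspace_size);

CHECK(infiniopAwqMarlinGemm(desc, workspace, workspace_size,
                            c, a, b, b_bias, b_scales, a_scales, global_scales,
                            b_zeros, g_idx, perm,
                            /*b_q_type_id=*/vllm::kU4.id(), // assumed: AWQ u4 weights
                            /*is_k_full=*/true, /*use_atomic_add=*/false,
                            /*use_fp32_reduce=*/true, /*is_zp_float=*/false,
                            stream));

cudaFree(workspace);
CHECK(infiniopDestroyAwqMarlinGemmDescriptor(desc));
```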
diff --git a/src/infiniop/ops/awq_marlin_gemm/awq_marlin_gemm.h b/src/infiniop/ops/awq_marlin_gemm/awq_marlin_gemm.h
new file mode 100644
index 000000000..f1df9e736
--- /dev/null
+++ b/src/infiniop/ops/awq_marlin_gemm/awq_marlin_gemm.h
@@ -0,0 +1,57 @@
+#ifndef AWQ_MARLIN_GEMM_H
+#define AWQ_MARLIN_GEMM_H
+
+#include "../../operator.h"
+#include "info.h"
+
+#define DESCRIPTOR(NAMESPACE)                                                  \
+                                                                               \
+    namespace op::awq_marlin_gemm::NAMESPACE {                                 \
+    class Descriptor final : public InfiniopDescriptor {                       \
+        struct Opaque;                                                         \
+        Opaque *_opaque;                                                       \
+        AwqMarlinGemmInfo _info;                                               \
+        size_t _workspace_size;                                                \
+                                                                               \
+        Descriptor(                                                            \
+            Opaque *opaque,                                                    \
+            AwqMarlinGemmInfo info,                                            \
+            size_t workspace_size,                                             \
+            infiniDevice_t device_type,                                        \
+            int device_id)                                                     \
+            : InfiniopDescriptor{device_type, device_id},                      \
+              _opaque(opaque),                                                 \
+              _info(info),                                                     \
+              _workspace_size(workspace_size) {}                               \
+                                                                               \
+    public:                                                                    \
+        ~Descriptor();                                                         \
+                                                                               \
+        size_t workspaceSize() const { return _workspace_size; }               \
+                                                                               \
+        static infiniStatus_t create(                                          \
+            infiniopHandle_t handle,                                           \
+            Descriptor **desc_ptr,                                             \
+            infiniopTensorDescriptor_t out_desc,                               \
+            infiniopTensorDescriptor_t a_desc,                                 \
+            infiniopTensorDescriptor_t b_desc,                                 \
+            infiniopTensorDescriptor_t b_bias_desc,                            \
+            infiniopTensorDescriptor_t b_scales_desc,                          \
+            infiniopTensorDescriptor_t a_scales_desc,                          \
+            infiniopTensorDescriptor_t global_scales_desc,                     \
+            infiniopTensorDescriptor_t b_zeros_desc,                           \
+            infiniopTensorDescriptor_t g_idx_desc,                             \
+            infiniopTensorDescriptor_t perm_desc);                             \
+                                                                               \
+        infiniStatus_t calculate(                                              \
+            void *workspace, size_t workspace_size,                            \
+            void *c,                                                           \
+            const void *a, const void *b,                                      \
+            void *b_bias, void *b_scales, void *a_scales, void *global_scales, \
+            void *b_zeros, void *g_idx, void *perm,                            \
+            int64_t b_q_type_id, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float, \
+            void *stream) const;                                               \
+    };                                                                         \
+    }
+
+#endif // AWQ_MARLIN_GEMM_H
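Note: `DESCRIPTOR(NAMESPACE)` stamps out one identical `Descriptor` class per backend namespace; each backend supplies its own `Opaque` state and out-of-line method definitions. A hypothetical instantiation (the `nvidia` backend name and path are illustrative, not part of this patch):

```cpp
// in a backend header, e.g. .../awq_marlin_gemm/nvidia/awq_marlin_gemm_nvidia.h (path assumed)
#include "../awq_marlin_gemm.h"

DESCRIPTOR(nvidia) // defines op::awq_marlin_gemm::nvidia::Descriptor

// The backend's .cu file then provides:
//   op::awq_marlin_gemm::nvidia::Descriptor::create(...)
//   op::awq_marlin_gemm::nvidia::Descriptor::calculate(...)
//   op::awq_marlin_gemm::nvidia::Descriptor::~Descriptor()
```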
diff --git a/src/infiniop/ops/awq_marlin_gemm/core/scalar_type.hpp b/src/infiniop/ops/awq_marlin_gemm/core/scalar_type.hpp
new file mode 100644
index 000000000..1e400c1c3
--- /dev/null
+++ b/src/infiniop/ops/awq_marlin_gemm/core/scalar_type.hpp
@@ -0,0 +1,252 @@
+#pragma once
+
+#include <cstddef>
+#include <cstdint>
+#include <stdexcept>
+#include <string>
+#include <tuple>
+#include <type_traits>
+#include <utility>
+
+namespace vllm {
+
+class ScalarType {
+public:
+  enum NanRepr : uint8_t {
+    NAN_NONE = 0,
+    NAN_IEEE_754 = 1,
+    NAN_EXTD_RANGE_MAX_MIN = 2,
+
+    NAN_REPR_ID_MAX
+  };
+
+  constexpr ScalarType(uint8_t exponent, uint8_t mantissa, bool signed_,
+                       int32_t bias, bool finite_values_only = false,
+                       NanRepr nan_repr = NAN_IEEE_754)
+      : exponent(exponent),
+        mantissa(mantissa),
+        signed_(signed_),
+        bias(bias),
+        finite_values_only(finite_values_only),
+        nan_repr(nan_repr) {}
+
+  // -----------------------
+  // Integer
+  // -----------------------
+  static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) {
+    return ScalarType(0, size_bits - 1, true, bias);
+  }
+
+  static constexpr ScalarType uint(uint8_t size_bits, int32_t bias = 0) {
+    return ScalarType(0, size_bits, false, bias);
+  }
+
+  // -----------------------
+  // Floating point (constexpr-safe: no validation)
+  // -----------------------
+  static constexpr ScalarType float_IEEE754(uint8_t exponent,
+                                            uint8_t mantissa) {
+    return ScalarType(exponent, mantissa, true, 0, false, NAN_IEEE_754);
+  }
+
+  static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa,
+                                     bool finite_values_only,
+                                     NanRepr nan_repr) {
+    return ScalarType(exponent, mantissa, true, 0,
+                      finite_values_only, nan_repr);
+  }
+
+  // -----------------------
+  // Runtime-checked (optional)
+  // -----------------------
+  static inline ScalarType float_checked(uint8_t exponent,
+                                         uint8_t mantissa,
+                                         bool finite_values_only,
+                                         NanRepr nan_repr) {
+    if (!(nan_repr < NAN_REPR_ID_MAX)) {
+      throw std::runtime_error("Invalid NanRepr");
+    }
+
+    if (!(mantissa > 0 && exponent > 0)) {
+      throw std::runtime_error("mantissa/exponent must be > 0");
+    }
+
+    if (nan_repr == NAN_IEEE_754) {
+      throw std::runtime_error("use float_IEEE754");
+    }
+
+    return float_(exponent, mantissa, finite_values_only, nan_repr);
+  }
+
+  uint8_t const exponent;
+  uint8_t const mantissa;
+  bool const signed_;
+  int32_t const bias;
+
+  bool const finite_values_only;
+  NanRepr const nan_repr;
+
+  using Id = int64_t;
+
+private:
+  template <typename T>
+  static constexpr size_t member_id_field_width() {
+    using D = std::decay_t<T>;
+    return std::is_same<D, bool>::value ? 1 : sizeof(D) * 8;
+  }
+
+  template <typename Fn, typename Init, typename Member, typename... Rest>
+  static constexpr auto reduce_members_helper(Fn f, Init val, Member member,
+                                              Rest... rest) {
+    auto new_val = f(val, member);
+    if constexpr (sizeof...(rest) > 0) {
+      return reduce_members_helper(f, new_val, rest...);
+    } else {
+      return new_val;
+    }
+  }
+
+  template <typename Fn, typename Init>
+  constexpr auto reduce_members(Fn f, Init init) const {
+    return reduce_members_helper(f, init, exponent, mantissa, signed_, bias,
+                                 finite_values_only, nan_repr);
+  }
+
+  template <typename Fn, typename Init>
+  static constexpr auto reduce_member_types(Fn f, Init init) {
+    constexpr auto dummy = ScalarType(0, 0, false, 0, false, NAN_NONE);
+    return dummy.reduce_members(f, init);
+  }
+
+  static constexpr auto id_size_bits() {
+    return reduce_member_types(
+        [](int acc, auto member) -> int {
+          return acc + member_id_field_width<decltype(member)>();
+        },
+        0);
+  }
+
+public:
+  constexpr Id id() const {
+    static_assert(id_size_bits() <= sizeof(Id) * 8,
+                  "ScalarType id too large");
+
+    auto fn = [](std::pair<int64_t, int> result, auto member) {
+      auto [id, offset] = result;
+      constexpr auto bits = member_id_field_width<decltype(member)>();
+      return std::pair<int64_t, int>{
+          id | (int64_t(member) & ((uint64_t(1) << bits) - 1)) << offset,
+          offset + bits};
+    };
+
+    return reduce_members(fn, std::pair<int64_t, int>{}).first;
+  }
+
+  static constexpr ScalarType from_id(Id id) {
+    auto fn = [id](auto result, auto member) {
+      using T = decltype(member);
+      auto [tuple, offset] = result;
+      constexpr auto bits = member_id_field_width<T>();
+      auto val = static_cast<T>((id >> offset) & ((uint64_t(1) << bits) - 1));
+      return std::pair{std::tuple_cat(tuple, std::make_tuple(val)), offset + bits};
+    };
+
+    auto [args, _] = reduce_member_types(fn, std::pair<std::tuple<>, int>{});
+
+    return std::apply([](auto... xs) { return ScalarType(xs...); }, args);
+  }
+
+  constexpr int64_t size_bits() const {
+    return mantissa + exponent + (signed_ ? 1 : 0);
+  }
+
+  constexpr bool is_signed() const { return signed_; }
+  constexpr bool is_integer() const { return exponent == 0; }
+  constexpr bool is_floating_point() const { return exponent > 0; }
+
+  constexpr bool is_ieee_754() const {
+    return is_floating_point() && !finite_values_only && nan_repr == NAN_IEEE_754;
+  }
+
+  constexpr bool has_nans() const {
+    return is_floating_point() && nan_repr != NAN_NONE;
+  }
+
+  constexpr bool has_infs() const {
+    return is_floating_point() && !finite_values_only;
+  }
+
+  constexpr bool has_bias() const { return bias != 0; }
+
+  std::string str() const {
+    if (is_floating_point()) {
+      auto ret = "float" + std::to_string(size_bits()) + "_e" + std::to_string(exponent) + "m" + std::to_string(mantissa);
+
+      if (!is_ieee_754()) {
+        if (finite_values_only) {
+          ret += "f";
+        }
+        if (nan_repr != NAN_NONE) {
+          ret += "n";
+        }
+      }
+      return ret;
+    } else {
"int" : "uint") + std::to_string(size_bits()); + if (has_bias()) { + ret += "b" + std::to_string(bias); + } + return ret; + } + } + + constexpr bool operator==(ScalarType const &other) const { + return mantissa == other.mantissa && exponent == other.exponent && bias == other.bias && signed_ == other.signed_ && finite_values_only == other.finite_values_only && nan_repr == other.nan_repr; + } +}; + +using ScalarTypeId = ScalarType::Id; + +// ----------------------- +// 原始常量(完全保留) +// ----------------------- + +static inline constexpr auto kS4 = ScalarType::int_(4); +static inline constexpr auto kU4 = ScalarType::uint(4); +static inline constexpr auto kU4B8 = ScalarType::uint(4, 8); +static inline constexpr auto kS8 = ScalarType::int_(8); +static inline constexpr auto kU8 = ScalarType::uint(8); +static inline constexpr auto kU8B128 = ScalarType::uint(8, 128); + +static inline constexpr auto kFE2M1f = ScalarType::float_(2, 1, true, ScalarType::NAN_NONE); +static inline constexpr auto kFE3M2f = ScalarType::float_(3, 2, true, ScalarType::NAN_NONE); +static inline constexpr auto kFE4M3fn = ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN); +static inline constexpr auto kFE8M0fnu = ScalarType(8, 0, false, 0, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN); +static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2); +static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7); +static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10); + +// 🔥 关键:alias(不能丢!) + +static inline constexpr auto kInt4 = kS4; +static inline constexpr auto kUint4 = kU4; +static inline constexpr auto kUint4b8 = kU4B8; +static inline constexpr auto kInt8 = kS8; +static inline constexpr auto kUint8 = kU8; +static inline constexpr auto kUint8b128 = kU8B128; + +static inline constexpr auto kFloat4_e2m1f = kFE2M1f; +static inline constexpr auto kFloat6_e3m2f = kFE3M2f; +static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn; +static inline constexpr auto kFloat8_e5m2 = kFE5M2; +static inline constexpr auto kFloat16_e8m7 = kFE8M7; +static inline constexpr auto kFloat16_e5m10 = kFE5M10; + +// ⭐ 这些就是你报错缺失的 +static inline constexpr auto kHalf = kFE5M10; +static inline constexpr auto kFloat16 = kHalf; +static inline constexpr auto kBFloat16 = kFE8M7; + +static inline constexpr auto kFloat16Id = kFloat16.id(); + +} // namespace vllm diff --git a/src/infiniop/ops/awq_marlin_gemm/core/source_location.h b/src/infiniop/ops/awq_marlin_gemm/core/source_location.h new file mode 100644 index 000000000..9a06fb380 --- /dev/null +++ b/src/infiniop/ops/awq_marlin_gemm/core/source_location.h @@ -0,0 +1,40 @@ +/// \file source_location.h +/// \brief Portable `source_location` wrapper. +/// +/// Uses `std::source_location` when available (C++20), otherwise falls +/// back to a minimal stub that returns empty/zero values. 
diff --git a/src/infiniop/ops/awq_marlin_gemm/core/source_location.h b/src/infiniop/ops/awq_marlin_gemm/core/source_location.h
new file mode 100644
index 000000000..9a06fb380
--- /dev/null
+++ b/src/infiniop/ops/awq_marlin_gemm/core/source_location.h
@@ -0,0 +1,40 @@
+/// \file source_location.h
+/// \brief Portable `source_location` wrapper.
+///
+/// Uses `std::source_location` when available (C++20), otherwise falls
+/// back to a minimal stub that returns empty/zero values.
+
+#pragma once
+#include <version>
+
+/// NOTE: fall back to a minimal source_location implementation
+#if defined(__cpp_lib_source_location)
+#include <source_location>
+
+using source_location_t = std::source_location;
+
+#else
+
+struct source_location_fallback {
+public:
+    static constexpr source_location_fallback current() noexcept {
+        return source_location_fallback{};
+    }
+    constexpr source_location_fallback() noexcept = default;
+    constexpr unsigned line() const noexcept {
+        return 0;
+    }
+    constexpr unsigned column() const noexcept {
+        return 0;
+    }
+    constexpr const char *file_name() const noexcept {
+        return "";
+    }
+    constexpr const char *function_name() const noexcept {
+        return "";
+    }
+};
+
+using source_location_t = source_location_fallback;
+
+#endif
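The point of the `source_location_t` alias is that a defaulted argument of `source_location_t::current()` is evaluated at each call site, so callers get their own file/line for free; pre-C++20 builds silently degrade to empty strings and zeros. A small illustration (the `log_here` helper is hypothetical, not in the patch):

```cpp
#include <cstdio>
#include "source_location.h"

// The default argument is evaluated where log_here() is *called*, not where
// it is defined, so each caller reports its own location.
void log_here(const char *msg,
              source_location_t loc = source_location_t::current()) {
    std::printf("%s:%u: %s\n", loc.file_name(), loc.line(), msg);
}
```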
diff --git a/src/infiniop/ops/awq_marlin_gemm/core/utils.h b/src/infiniop/ops/awq_marlin_gemm/core/utils.h
new file mode 100644
index 000000000..17f985f59
--- /dev/null
+++ b/src/infiniop/ops/awq_marlin_gemm/core/utils.h
@@ -0,0 +1,216 @@
+/// \file utils.h
+/// \brief Host-side C++ utilities used by JIT kernel wrappers.
+///
+/// Provides:
+///  - `DebugInfo`       - wraps `std::source_location` for error reporting.
+///  - `RuntimeCheck`    - runtime assertion with formatted error messages.
+///  - `Panic`           - unconditional abort with formatted error messages.
+///  - `pointer::offset` - safe void-pointer arithmetic (host side).
+///  - `div_ceil`        - integer ceiling division.
+///  - `dtype_bytes`     - byte width of a `DLDataType`.
+///  - `irange`          - Python-style integer range for range-for loops.
+
+#pragma once
+
+// ref: https://forums.developer.nvidia.com/t/c-20s-source-location-compilation-error-when-using-nvcc-12-1/258026/3
+#ifdef __CUDACC__
+#include <cuda.h>
+#if CUDA_VERSION <= 12010
+
+#pragma push_macro("__cpp_consteval")
+#pragma push_macro("_NODISCARD")
+#pragma push_macro("__builtin_LINE")
+
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wbuiltin-macro-redefined"
+#define __cpp_consteval 201811L
+#pragma clang diagnostic pop
+
+#ifdef _NODISCARD
+#undef _NODISCARD
+#define _NODISCARD
+#endif
+
+#define consteval constexpr
+
+#include "source_location.h"
+
+#undef consteval
+#pragma pop_macro("__cpp_consteval")
+#pragma pop_macro("_NODISCARD")
+#else // __CUDACC__ && CUDA_VERSION > 12010
+#include "source_location.h"
+#endif
+#else // no __CUDACC__
+#include "source_location.h"
+#endif
+
+#include <dlpack/dlpack.h>
+#include <sstream>
+// #include
+#include <stdexcept>
+#include <string_view>
+// #include
+#include <type_traits>
+#include <utility>
+
+namespace host {
+
+template <typename...>
+inline constexpr bool dependent_false_v = false;
+
+/// \brief Source-location wrapper for debug/error messages.
+struct DebugInfo : public source_location_t {
+    DebugInfo(source_location_t loc = source_location_t::current()) : source_location_t(loc) {}
+};
+
+/// \brief Exception type thrown by `RuntimeCheck` and `Panic`.
+struct PanicError : public std::runtime_error {
+public:
+    explicit PanicError(std::string msg) : runtime_error(msg), m_message(std::move(msg)) {}
+    auto root_cause() const -> std::string_view {
+        const auto str = std::string_view{m_message};
+        const auto pos = str.find(": ");
+        return pos == std::string_view::npos ? str : str.substr(pos + 2);
+    }
+
+private:
+    std::string m_message;
+};
+
+/// \brief Unconditionally abort with a formatted error message.
+template <typename... Args>
+[[noreturn]] inline auto panic(DebugInfo location, Args &&...args) -> void {
+    std::ostringstream os;
+    os << "Runtime check failed at " << location.file_name() << ":" << location.line();
+    if constexpr (sizeof...(args) > 0) {
+        os << ": ";
+        (os << ... << std::forward<Args>(args));
+    } else {
+        os << " in " << location.function_name();
+    }
+    throw PanicError(std::move(os).str());
+}
+
+/**
+ * \brief Runtime assertion: panics with a formatted message when `condition`
+ * is false. Extra `args` are streamed to the error message.
+ *
+ * Example:
+ * \code
+ * RuntimeCheck(n > 0, "n must be positive, got ", n);
+ * \endcode
+ */
+template <typename... Args>
+struct RuntimeCheck {
+    template <typename Cond>
+    explicit RuntimeCheck(Cond &&condition, Args &&...args, DebugInfo location = {}) {
+        if (condition) {
+            return;
+        }
+        [[unlikely]] ::host::panic(location, std::forward<Args>(args)...);
+    }
+    template <typename Cond>
+    explicit RuntimeCheck(DebugInfo location, Cond &&condition, Args &&...args) {
+        if (condition) {
+            return;
+        }
+        [[unlikely]] ::host::panic(location, std::forward<Args>(args)...);
+    }
+};
+
+template <typename... Args>
+struct Panic {
+    explicit Panic(Args &&...args, DebugInfo location = {}) {
+        ::host::panic(location, std::forward<Args>(args)...);
+    }
+    explicit Panic(DebugInfo location, Args &&...args) {
+        ::host::panic(location, std::forward<Args>(args)...);
+    }
+    [[noreturn]] ~Panic() {
+        std::terminate();
+    }
+};
+
+template <typename Cond, typename... Args>
+explicit RuntimeCheck(Cond &&, Args &&...) -> RuntimeCheck<Args...>;
+
+template <typename Cond, typename... Args>
+explicit RuntimeCheck(DebugInfo, Cond &&, Args &&...) -> RuntimeCheck<Args...>;
+
+template <typename... Args>
+explicit Panic(Args &&...) -> Panic<Args...>;
+
+template <typename... Args>
+explicit Panic(DebugInfo, Args &&...) -> Panic<Args...>;
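To make the failure path concrete, here is a hypothetical use of `RuntimeCheck` together with `PanicError::root_cause()` (the `load_weights` function and the multiple-of-64 rule are illustrative only):

```cpp
#include "utils.h"

void load_weights(int n) {
    // Throws PanicError with the caller's file:line baked into the message.
    host::RuntimeCheck(n % 64 == 0, "n must be a multiple of 64, got ", n);
}

void demo() {
    try {
        load_weights(100);
    } catch (const host::PanicError &e) {
        // e.what()       == "Runtime check failed at <file>:<line>: n must be a multiple of 64, got 100"
        // e.root_cause() == "n must be a multiple of 64, got 100"
    }
}
```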
+namespace pointer {
+
+// we only allow void * pointer arithmetic for safety
+
+template <typename... U,
+          typename = std::enable_if_t<(std::is_integral<U>::value && ...)>>
+inline auto offset(void *ptr, U... offset) -> void * {
+    return static_cast<char *>(ptr) + (... + offset);
+}
+
+template <typename... U,
+          typename = std::enable_if_t<(std::is_integral<U>::value && ...)>>
+inline auto offset(const void *ptr, U... offset) -> const void * {
+    return static_cast<const char *>(ptr) + (... + offset);
+}
+
+} // namespace pointer
+
+/// \brief Integer ceiling division: ceil(a / b).
+template <typename T, typename U>
+inline constexpr auto div_ceil(T a, U b) {
+    static_assert(std::is_integral<T>::value, "T must be integral");
+    static_assert(std::is_integral<U>::value, "U must be integral");
+    return (a + b - 1) / b;
+}
+
+/// \brief Returns the byte width of a DLPack data type.
+inline auto dtype_bytes(DLDataType dtype) -> std::size_t {
+    return static_cast<std::size_t>(dtype.bits / 8);
+}
+
+// ====================== Fix start: pure C++11-compatible irange ======================
+// All std::ranges / std::integral uses removed, for full compatibility with
+// older CUDA compilers
+
+template <typename T>
+struct IntegerRange {
+    T start_;
+    T end_;
+
+    struct Iterator {
+        T value;
+
+        T operator*() const { return value; }
+        Iterator &operator++() {
+            ++value;
+            return *this;
+        }
+        bool operator!=(const Iterator &other) const {
+            return value != other.value;
+        }
+    };
+
+    Iterator begin() const { return {start_}; }
+    Iterator end() const { return {end_}; }
+};
+
+/// Python-style integer range: irange(n) -> [0, n)
+template <typename T>
+IntegerRange<T> irange(T end) {
+    return {0, end};
+}
+
+/// Python-style integer range: irange(start, end) -> [start, end)
+template <typename T>
+IntegerRange<T> irange(T start, T end) {
+    return {start, end};
+}
+// ====================== Fix end ======================
+
+} // namespace host
diff --git a/src/infiniop/ops/awq_marlin_gemm/info.h b/src/infiniop/ops/awq_marlin_gemm/info.h
new file mode 100644
index 000000000..9c74b413f
--- /dev/null
+++ b/src/infiniop/ops/awq_marlin_gemm/info.h
@@ -0,0 +1,61 @@
+#ifndef __AWQ_MARLIN_GEMM_INFO_H__
+#define __AWQ_MARLIN_GEMM_INFO_H__
+
+#include "../../../utils.h"
+#include "../../tensor.h"
+#include "marlin/marlin.cuh"
+#include <cstddef>
+
+#include <cstdint>
+
+namespace op::awq_marlin_gemm {
+
+class AwqMarlinGemmInfo {
+    AwqMarlinGemmInfo() = default;
+
+public:
+    infiniDtype_t a_dtype, b_dtype, c_dtype, s_dtype;
+    size_t size_m, size_k, size_n;
+    int num_groups;
+    size_t b_q_size_0, b_q_size_1, b_zeros_size_1;
+    ptrdiff_t a_stride_0;
+
+    static utils::Result<AwqMarlinGemmInfo> create(
+        infiniopTensorDescriptor_t out_desc,
+        infiniopTensorDescriptor_t a_desc,
+        infiniopTensorDescriptor_t b_desc,
+        infiniopTensorDescriptor_t b_bias_desc,
+        infiniopTensorDescriptor_t b_scales_desc,
+        infiniopTensorDescriptor_t a_scales_desc,
+        infiniopTensorDescriptor_t global_scales_desc,
+        infiniopTensorDescriptor_t b_zeros_desc,
+        infiniopTensorDescriptor_t g_idx_desc,
+        infiniopTensorDescriptor_t perm_desc) {
+        CHECK_OR_RETURN(
+            out_desc != nullptr && a_desc != nullptr && b_desc != nullptr && b_scales_desc != nullptr,
+            INFINI_STATUS_NULL_POINTER);
+        const infiniDtype_t a_dtype = a_desc->dtype();
+        const infiniDtype_t b_dtype = b_desc->dtype();
+        const infiniDtype_t c_dtype = out_desc->dtype();
+        const infiniDtype_t s_dtype = b_scales_desc->dtype();
+
+        size_t size_m = a_desc->dim(0);
+        size_t size_k = a_desc->dim(1);
+        size_t size_n = out_desc->dim(1);
+
+        int num_groups = static_cast<int>(b_scales_desc->dim(0));
+        size_t b_q_size_0 = b_desc->dim(0);
+        size_t b_q_size_1 = b_desc->dim(1);
+        size_t b_zeros_size_1 = b_zeros_desc->dim(1);
+        ptrdiff_t a_stride_0 = a_desc->strides()[0];
+
+        return utils::Result<AwqMarlinGemmInfo>(
+            AwqMarlinGemmInfo{a_dtype, b_dtype, c_dtype, s_dtype,
+                              size_m, size_k, size_n,
+                              num_groups, b_q_size_0, b_q_size_1, b_zeros_size_1, a_stride_0});
+    }
+};
+
+} // namespace op::awq_marlin_gemm
+
+#endif // __AWQ_MARLIN_GEMM_INFO_H__
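For reference, the two small numeric helpers in utils.h compose as you would expect; a self-contained example (not from the patch):

```cpp
#include <cstdio>
#include "core/utils.h"

int main() {
    // div_ceil(10, 4) == 3: three 4-wide tiles are needed to cover 10 elements.
    static_assert(host::div_ceil(10, 4) == 3, "ceiling division");

    // irange(2, 5) visits 2, 3, 4 -- a Python-style half-open range.
    for (int i : host::irange(2, 5)) {
        std::printf("tile %d\n", i);
    }
    return 0;
}
```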
diff --git a/src/infiniop/ops/awq_marlin_gemm/marlin/dequant.h b/src/infiniop/ops/awq_marlin_gemm/marlin/dequant.h
new file mode 100644
index 000000000..1340b6bbe
--- /dev/null
+++ b/src/infiniop/ops/awq_marlin_gemm/marlin/dequant.h
@@ -0,0 +1,601 @@
+/*
+Fast Dequantization (Converting INT4/INT8/FP4/FP8 to FP16/BF16)
+
+The process of fast dequantization can be summarized as a combination
+of bitwise operations and floating-point computations:
+
+weight =>(bit_op / bitwise operations)=>
+f16_value =>(flop / floating-point computation)=>
+dequantized_weight
+
+Since the dequantized weights typically require subtracting the zero point and
+applying a scale factor, the floating-point computation step can be fused with
+the zero-point subtraction and scaling operations.
+
+The following are the parts that need to be modified for the fused operation
+of zero-point subtraction and scaling.
+
+## INT4 => FP16/BF16 or INT8 => FP16
+
+The floating-point computation is `__hsub2`
+
+If there are zero points:
+
+    flop(bit_op(weight)) - flop(bit_op(zp))
+  = sub(bit_op(weight), bias) - sub(bit_op(zp), bias)
+  = bit_op(weight) - bit_op(zp)
+
+so we don't need additional modification.
+
+If there are float zero points:
+
+    flop(bit_op(weight)) - fzp
+  = sub(bit_op(weight), bias) - fzp
+  = bit_op(weight) - (fzp + bias)
+
+where the `fzp + bias` can be computed at weight loading. But this
+may have accuracy issues, so we should not use this in most cases.
+
+If there are no zero points:
+
+    scale(flop(bit_op(weight)))
+  = scale(sub(bit_op(weight), bias))
+  = scale(bit_op(weight)) - scale(bias)
+  = fma(bit_op(weight), scale_factor, scale(bias))
+
+where the `scale(bias)` can be cached. But this may have accuracy issues,
+so we should not use this in most cases.
+
+
+## INT8 => BF16
+
+INT8 => BF16 is a special case: it uses byte_perm instead of a flop.
+We cannot fuse byte_perm with scaling.
+
+
+## FP4/FP8 => FP16/BF16
+
+    scale(flop(bit_op(weight)))
+  = scale(mul(bit_op(weight), multiplier))
+  = mul(bit_op(weight), scale_factor * multiplier)
+
+where `scale_factor * multiplier` can be computed at weight loading.
+
+*/
+
+#include "marlin_dtypes.cuh"
+
+namespace MARLIN_NAMESPACE_NAME {
+
+#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 750
+// Lookup-table based 3-input logical operation; explicitly used for
+// dequantization as the compiler does not seem to automatically recognize it in
+// all cases.
+template <int lut>
+__device__ inline int lop3(int a, int b, int c) {
+  int res;
+  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+               : "=r"(res)
+               : "r"(a), "r"(b), "r"(c), "n"(lut));
+  return res;
+}
+
+// Constructs destination register by taking bytes from 2 sources (based on
+// mask)
+template <int start_byte, int mask>
+__device__ inline uint32_t prmt(uint32_t a) {
+  uint32_t res;
+  asm volatile("prmt.b32 %0, %1, %2, %3;\n"
+               : "=r"(res)
+               : "r"(a), "n"(start_byte), "n"(mask));
+  return res;
+}
+
+template <typename scalar_t2, vllm::ScalarTypeId w_type_id, bool skip_flop = false>
+__device__ inline void dequant(int q, scalar_t2 *frag_b);
+
+//
+// Efficiently dequantize 4bit values packed in an int32 value into a full
+// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below,
+// with some small changes:
+// - FP16:
+// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
+// - BF16:
+// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
+//
+template <>
+__device__ inline void dequant<half2, vllm::kU4.id(), true>(int q,
+                                                            half2 *frag_b) {
+  const int MASK = 0x000f000f;
+  const int EX = 0x64006400;
+  // Guarantee that the `(a & b) | c` operations are LOP3s.
+ int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + q >>= 4; + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + + frag_b[0] = *reinterpret_cast(&lo); + frag_b[1] = *reinterpret_cast(&hi); +} + +template <> +__device__ inline void dequant(int q, + half2 *frag_b) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + // clang-format off + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // clang-format on + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. + const int SUB = 0x64086408; + const int MUL = 0x2c002c00; + const int ADD = 0xd480d480; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); +} + +template <> +__device__ inline void dequant(int q, + half2 *frag_b) { + dequant(q, frag_b); +} + +template <> +__device__ inline void dequant(int q, + half2 *frag_b) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + // clang-format off + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // clang-format on + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. + const int SUB = 0x64006400; + const int MUL = 0x2c002c00; + const int ADD = 0xd400d400; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); +} + +template <> +__device__ inline void dequant( + int q, nv_bfloat162 *frag_b) { + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t EX = 0x43004300; + + // Guarantee that the `(a & b) | c` operations are LOP3s. 
+ // clang-format off + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + q >>= 4; + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + // clang-format on + + frag_b[0] = *reinterpret_cast(&lo); + frag_b[1] = *reinterpret_cast(&hi); +} + +template <> +__device__ inline void dequant( + int q, nv_bfloat162 *frag_b) { + dequant(q, frag_b); + + static constexpr uint32_t SUB = 0x43084308; + + frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&SUB)); + frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&SUB)); +} + +template <> +__device__ inline void dequant( + int q, nv_bfloat162 *frag_b) { + dequant(q, frag_b); +} + +template <> +__device__ inline void dequant( + int q, nv_bfloat162 *frag_b) { + dequant(q, frag_b); + + static constexpr uint32_t SUB = 0x43004300; + + frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&SUB)); + frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&SUB)); +} + +// +// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or +// bf16 Reference: +// - FP16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 +// - BF16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175 +// +template <> +__device__ inline void dequant(int q, + half2 *frag_b) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + frag_b[0] = *reinterpret_cast(&lo); + frag_b[1] = *reinterpret_cast(&hi); +} + +template <> +__device__ inline void dequant( + int q, half2 *frag_b) { + dequant(q, frag_b); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + frag_b[0] = __hsub2(frag_b[0], + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(frag_b[1], + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); +} + +template <> +__device__ inline void dequant(int q, + half2 *frag_b) { + dequant(q, frag_b); +} + +template <> +__device__ inline void dequant(int q, + half2 *frag_b) { + dequant(q, frag_b); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400; + frag_b[0] = __hsub2(frag_b[0], + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(frag_b[1], + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); +} + +template <> +__device__ inline void dequant( + int q, nv_bfloat162 *frag_b) { + float fp32_intermediates[4]; + uint32_t *fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); + + static constexpr uint32_t fp32_base = 0x4B000000; + fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); + + fp32_intermediates[0] -= 8388736.f; + fp32_intermediates[1] -= 8388736.f; + fp32_intermediates[2] -= 8388736.f; + fp32_intermediates[3] -= 8388736.f; + + uint32_t *bf16_result_ptr = reinterpret_cast(frag_b); + bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], + fp32_intermediates_casted[1], 0x7632); + bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], + fp32_intermediates_casted[3], 0x7632); +} + +template <> +__device__ inline void dequant( + int q, nv_bfloat162 *frag_b) { + float 
fp32_intermediates[4]; + uint32_t *fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); + + static constexpr uint32_t fp32_base = 0x4B000000; + fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); + + fp32_intermediates[0] -= 8388608.f; + fp32_intermediates[1] -= 8388608.f; + fp32_intermediates[2] -= 8388608.f; + fp32_intermediates[3] -= 8388608.f; + + uint32_t *bf16_result_ptr = reinterpret_cast(frag_b); + bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], + fp32_intermediates_casted[1], 0x7632); + bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], + fp32_intermediates_casted[3], 0x7632); +} + +template <> +__device__ inline void dequant( + int q, half2 *frag_b) { + // Constants for FP8 (E4M3) and FP16 formats + constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5; + constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT; + constexpr int MASK = 0x7F007F00; + + // Extract and shift FP8 values to FP16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 8; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template <> +__device__ inline void dequant( + int q, half2 *frag_b) { + dequant(q, frag_b); + + // Constants for FP8 (E4M3) and FP16 formats + constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5; + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); + const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); + + // Convert to half2 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); +} + +template <> +__device__ inline void dequant( + int q, nv_bfloat162 *frag_b) { + // Constants for FP8 (E4M3) and BF16 formats + constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; + constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; + + constexpr int MASK = 0x7F007F00; + + // Extract and shift FP8 values to BF16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 8; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template <> +__device__ inline void dequant( + int q, nv_bfloat162 *frag_b) { + dequant(q, frag_b); + + // Constants for FP8 (E4M3) and BF16 formats + constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = (1 << (BF16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); + // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent + // position + constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23; + const nv_bfloat162 bias_reg = __float2bfloat162_rn(*reinterpret_cast(&BIAS)); + + // Convert to bfloat162 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); +} + +template <> +__device__ inline void dequant(int q, + half2 *frag_b) { + // Constants for FP4 (E2M1) and FP16 formats + constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5; + constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP4_EXPONENT; + constexpr int MASK = 
0x70007000; + + // Extract and shift FP4 values to FP16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 4; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template <> +__device__ inline void dequant( + int q, half2 *frag_b) { + dequant(q, frag_b); + + // Constants for FP4 (E2M1) and FP16 formats + constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5; + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = (1 << (FP16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1)); + const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); + + // Convert to half2 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); +} + +template <> +__device__ inline void dequant( + int q, nv_bfloat162 *frag_b) { + // Constants for FP4 (E2M1) and FP16 formats + constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8; + constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP4_EXPONENT; + constexpr int MASK = 0x70007000; + + // Extract and shift FP4 values to FP16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 4; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template <> +__device__ inline void dequant( + int q, nv_bfloat162 *frag_b) { + dequant(q, frag_b); + + // Constants for FP4 (E2M1) and BF16 formats + constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8; + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = (1 << (BF16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1)); + // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent + // position + constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23; + const nv_bfloat162 bias_reg = __float2bfloat162_rn(*reinterpret_cast(&BIAS)); + + // Convert to half2 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); +} + +template <> +__device__ inline void dequant<__nv_fp8x4_e4m3, vllm::kFE2M1f.id(), true>( + int q, __nv_fp8x4_e4m3 *frag_b) { + // Constants for FP4 (E2M1) and FP16 formats + constexpr int FP4_EXPONENT = 2, FP8_EXPONENT = 4; + constexpr int RIGHT_SHIFT = FP8_EXPONENT - FP4_EXPONENT; + constexpr int MASK = 0x70707070; + + // Extract and shift FP4 values to FP16 format + int Out1 = (q & 0x80808080) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 4; + int Out2 = (q & 0x80808080) | ((q & MASK) >> RIGHT_SHIFT); + + // Note1: reverse indexing is intentional because weights are permuted + // Note2: when dequant to 8bit type, we write to `frag_b[2]` instead of + // `frag_b[1]` to fit the layout of tensorcore + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template <> +__device__ inline void dequant( + int q, int32_t *frag_b) { + constexpr int repeated_zp = 0x08080808; + constexpr int MASK = 0x80808080; + + frag_b[0] = ((q & 0x0F0F0F0F | MASK) - repeated_zp) ^ MASK; + q >>= 4; + frag_b[1] = ((q & 0x0F0F0F0F | MASK) - repeated_zp) ^ MASK; +} + +template <> +__device__ inline void dequant<__nv_fp8x4_e4m3, vllm::kU4B8.id(), true>( + int q, __nv_fp8x4_e4m3 *frag_b) { + int s = q & 0x08080808; + int Out1 = ((q & 0x07070707) | (s << 4)) + (s >> 3); + q >>= 4; + s = q & 0x08080808; + int Out2 = ((q & 0x07070707) | (s << 4)) + (s >> 3); 
+
+  frag_b[0] = *reinterpret_cast<__nv_fp8x4_e4m3 *>(&Out1);
+  frag_b[1] = *reinterpret_cast<__nv_fp8x4_e4m3 *>(&Out2);
+}
+
+template <typename scalar_t2, vllm::ScalarTypeId s_type_id>
+__device__ inline void dequant_fp8_scales(int q, scalar_t2 *frag_b);
+
+template <>
+__device__ inline void dequant_fp8_scales<half2, vllm::kFE4M3fn.id()>(
+    int q, half2 *frag_b) {
+  int Out1 = (q & 0xFF00FF00) >> 1;
+  q <<= 8;
+  int Out2 = (q & 0xFF00FF00) >> 1;
+
+  // Note: reverse indexing is intentional because weights are permuted
+  frag_b[1] = *reinterpret_cast<half2 *>(&Out1);
+  frag_b[0] = *reinterpret_cast<half2 *>(&Out2);
+}
+
+template <>
+__device__ inline void dequant_fp8_scales<nv_bfloat162, vllm::kFE4M3fn.id()>(
+    int q, nv_bfloat162 *frag_b) {
+  constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8;
+  constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT;
+  constexpr int MASK = 0x7F007F00;
+
+  // Extract and shift FP8 values to BF16 format
+  int Out1 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT);
+  q <<= 8;
+  int Out2 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT);
+
+  // Note: reverse indexing is intentional because weights are permuted
+  frag_b[1] = *reinterpret_cast<nv_bfloat162 *>(&Out1);
+  frag_b[0] = *reinterpret_cast<nv_bfloat162 *>(&Out2);
+}
+
+template <>
+__device__ inline void dequant_fp8_scales<nv_bfloat162, vllm::kFE8M0fnu.id()>(
+    int q, nv_bfloat162 *frag_b) {
+  // In this conversion, 2 ** -127 in FP8E8M0 would become 0 in BF16,
+  // but we assume that such an extreme value would not occur in real models.
+  int Out1 = (q & 0xFF00FF00) >> 1;
+  q <<= 7;
+  int Out2 = q & 0x7F807F80;
+
+  // Note: reverse indexing is intentional because weights are permuted
+  frag_b[1] = *reinterpret_cast<nv_bfloat162 *>(&Out1);
+  frag_b[0] = *reinterpret_cast<nv_bfloat162 *>(&Out2);
+}
+
+// subtract the zero point in quantized format and then dequantize
+template <typename scalar_t2, vllm::ScalarTypeId w_type_id, bool skip_flop = false>
+__device__ inline void sub_zp_and_dequant(int q, scalar_t2 *frag_b, int zp);
+
+template <>
+__device__ inline void sub_zp_and_dequant<int32_t, vllm::kU4.id(), true>(
+    int q, int32_t *frag_b, int zp) {
+  // INT4 with zp -> INT8
+  // see https://github.com/vllm-project/vllm/pull/24722
+  int repeated_zp = 0x01010101 * zp;
+  int MASK = 0x80808080;
+
+  frag_b[0] = ((q & 0x0F0F0F0F | MASK) - repeated_zp) ^ MASK;
+  q >>= 4;
+  frag_b[1] = ((q & 0x0F0F0F0F | MASK) - repeated_zp) ^ MASK;
+}
+
+template <>
+__device__ inline void sub_zp_and_dequant<__nv_fp8x4_e4m3, vllm::kU4.id(),
+                                          true>(int q, __nv_fp8x4_e4m3 *frag_b,
+                                                int zp) {
+  // INT4 with zp -> FP8
+  // see https://github.com/vllm-project/vllm/pull/24722
+  uint32_t u_q = *reinterpret_cast<uint32_t *>(&q);
+  uint32_t u_zp = *reinterpret_cast<uint32_t *>(&zp);
+  uint32_t u_zp1 = u_zp + 1;
+  uint32_t repeated_zp = 0x01010101 * u_zp;
+
+  uint32_t q0, s;
+  q0 = (u_q & 0x0F0F0F0F) | 0x70707070;
+  s = (q0 + repeated_zp) & 0x80808080;
+  uint32_t Out1 = (q0 + (s >> 7) * u_zp1) & 0x0F0F0F0F | s;
+
+  u_q >>= 4;
+  q0 = (u_q & 0x0F0F0F0F) | 0x70707070;
+  s = (q0 + repeated_zp) & 0x80808080;
+  uint32_t Out2 = (q0 + (s >> 7) * u_zp1) & 0x0F0F0F0F | s;
+
+  frag_b[0] = *reinterpret_cast<__nv_fp8x4_e4m3 *>(&Out1);
+  frag_b[1] = *reinterpret_cast<__nv_fp8x4_e4m3 *>(&Out2);
+}
+
+#endif
+
+} // namespace MARLIN_NAMESPACE_NAME
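The fp16 paths above all lean on one identity: OR-ing a 4-bit value v into the low mantissa bits of 1024.0 (fp16 bit pattern `0x6400`) produces exactly the fp16 value 1024 + v, so a single `__hsub2` against a magic constant recovers v (or v - 8 when the bias is fused, as with `0x6408`). The identity can be checked on the host with plain bit math; a standalone sketch, not code from the patch:

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Decode a *normal* fp16 bit pattern to float; enough for this demo.
static float fp16_bits_to_float(uint16_t h) {
    uint32_t sign = (h >> 15) & 1, exp = (h >> 10) & 0x1F, man = h & 0x3FF;
    assert(exp != 0 && exp != 31);
    uint32_t f = (sign << 31) | ((exp - 15 + 127) << 23) | (man << 13);
    float out;
    std::memcpy(&out, &f, sizeof(out));
    return out;
}

int main() {
    for (uint16_t v = 0; v < 16; ++v) {
        // 0x6400 is 1024.0 in fp16; at exponent 2^10 each mantissa ulp is
        // worth exactly 1, so OR-ing in the nibble adds v. This is what the
        // lop3(q, 0x000f000f, 0x64006400) + __hsub2 sequence exploits.
        assert(fp16_bits_to_float(0x6400 | v) == 1024.0f + float(v));
    }
    return 0;
}
```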
diff --git a/src/infiniop/ops/awq_marlin_gemm/marlin/kernel.h b/src/infiniop/ops/awq_marlin_gemm/marlin/kernel.h
new file mode 100644
index 000000000..b1fabe323
--- /dev/null
+++ b/src/infiniop/ops/awq_marlin_gemm/marlin/kernel.h
@@ -0,0 +1,43 @@
+
+#ifndef MARLIN_NAMESPACE_NAME
+#define MARLIN_NAMESPACE_NAME marlin
+#endif
+
+#include "../core/scalar_type.hpp"
+#include "marlin.cuh"
+#include "marlin_dtypes.cuh"
+
+#define MARLIN_KERNEL_PARAMS                                                   \
+  const int4 *__restrict__ A, const int4 *__restrict__ B,                      \
+      int4 *__restrict__ C, int4 *__restrict__ C_tmp,                          \
+      const int4 *__restrict__ b_bias_ptr,                                     \
+      const float *__restrict__ a_scales_ptr,                                  \
+      const int4 *__restrict__ scales_ptr,                                     \
+      const float *__restrict__ global_scale_ptr,                              \
+      const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx,          \
+      int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \
+      bool has_bias, bool use_atomic_add, bool use_fp32_reduce,                \
+      int max_shared_mem
+
+namespace MARLIN_NAMESPACE_NAME {
+template <typename scalar_t,          // compute dtype, half or nv_bfloat16
+          const vllm::ScalarTypeId w_type_id, // weight ScalarType id
+          const int threads,          // number of threads in a threadblock
+          const int thread_m_blocks,  // number of 16x16 blocks in the m
+                                      // dimension (batchsize) of the
+                                      // threadblock
+          const int thread_n_blocks,  // same for n dimension (output)
+          const int thread_k_blocks,  // same for k dimension (reduction)
+          const bool m_block_size_8,  // whether m_block_size == 8
+                                      // only works when thread_m_blocks == 1
+          const int stages,           // number of stages for the async global->shared
+                                      // fetch pipeline
+          const int group_blocks,     // number of consecutive 16x16 blocks
+                                      // with a separate quantization scale
+          const bool is_zp_float      // is zero point of float16 type?
+          >
+__global__ void Marlin(MARLIN_KERNEL_PARAMS);
+
+}
diff --git a/src/infiniop/ops/awq_marlin_gemm/marlin/marlin.cuh b/src/infiniop/ops/awq_marlin_gemm/marlin/marlin.cuh
new file mode 100644
index 000000000..3fbb4c463
--- /dev/null
+++ b/src/infiniop/ops/awq_marlin_gemm/marlin/marlin.cuh
@@ -0,0 +1,178 @@
+#pragma once
+
+#ifndef _marlin_cuh
+#define _marlin_cuh
+
+#include <cuda.h>
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+
+#ifndef MARLIN_NAMESPACE_NAME
+#define MARLIN_NAMESPACE_NAME marlin
+#endif
+
+template <typename T>
+__device__ __forceinline__ uint32_t __cvta_generic_to_shared(T *ptr) {
+  size_t smem_addr;
+  asm volatile(
+      "cvta.to.shared.u64 %0, %1;"
+      : "=l"(smem_addr)
+      : "l"(ptr));
+  return static_cast<uint32_t>(smem_addr);
+}
+
+namespace MARLIN_NAMESPACE_NAME {
+
+// Marlin params
+
+// 8 warps are a good choice since every SM has 4 schedulers and having more
+// than 1 warp per schedule allows some more latency hiding. At the same time,
+// we want relatively few warps to have many registers per warp and small tiles.
+static constexpr int default_threads = 256;
+
+static constexpr int pipe_stages = 4; // 4 pipeline stages fit into shared memory
+
+static constexpr int min_thread_n = 64;
+static constexpr int min_thread_k = 64;
+static constexpr int max_thread_n = 256;
+
+static constexpr int tile_size = 16;
+static constexpr int max_par = 16;
+
+// Repack params
+static constexpr int repack_stages = 8;
+
+static constexpr int repack_threads = 256;
+
+static constexpr int tile_k_size = tile_size;
+static constexpr int tile_n_size = tile_k_size * 4;
+
+// Helpers
+template <typename T, int n>
+struct Vec {
+  T elems[n];
+  __device__ T &operator[](int i) { return elems[i]; }
+};
+
+using I4 = Vec<int, 4>;
+
+constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+
+__device__ inline void cp_async1_ca_pred(void *smem_ptr, const void *glob_ptr,
+                                         bool pred = true) {
+  if (pred) {
+    reinterpret_cast<int *>(smem_ptr)[0] = reinterpret_cast<const int *>(glob_ptr)[0];
+  }
+}
+
+__device__ inline void cp_async2_ca_pred(void *smem_ptr, const void *glob_ptr,
+                                         bool pred = true) {
+  if (pred) {
+    reinterpret_cast<int2 *>(smem_ptr)[0] = reinterpret_cast<const int2 *>(glob_ptr)[0];
+  }
+}
+
+__device__ inline void cp_async4_ca_pred(void *smem_ptr, const void *glob_ptr,
+                                         bool pred = true) {
+  if (pred) {
+    reinterpret_cast<int4 *>(smem_ptr)[0] = reinterpret_cast<const int4 *>(glob_ptr)[0];
+  }
+}
+
+__device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr,
+                                      bool pred = true) {
+  if (pred) {
+    reinterpret_cast<int4 *>(smem_ptr)[0] = reinterpret_cast<const int4 *>(glob_ptr)[0];
+  }
+}
+
+__device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) {
+  reinterpret_cast<int4 *>(smem_ptr)[0] = reinterpret_cast<const int4 *>(glob_ptr)[0];
+}
+
+__device__ inline void cp_async_fence() {}
+
+template <int n>
+__device__ inline void cp_async_wait() {}
+
+#else
+
+__device__ inline void cp_async1_ca_pred(void *smem_ptr, const void *glob_ptr,
+                                         bool pred = true) {
+  const int BYTES = 4;
+  uint32_t smem =
static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.ca.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +__device__ inline void cp_async2_ca_pred(void *smem_ptr, const void *glob_ptr, + bool pred = true) { + const int BYTES = 8; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.ca.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +__device__ inline void cp_async4_ca_pred(void *smem_ptr, const void *glob_ptr, + bool pred = true) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.ca.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +__device__ inline void cp_async4_pred(void *smem_ptr, const void *glob_ptr, + bool pred = true) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +__device__ inline void cp_async4(void *smem_ptr, const void *glob_ptr) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), "n"(BYTES)); +} + +__device__ inline void cp_async_fence() { + asm volatile("cp.async.commit_group;\n" ::); +} + +template +__device__ inline void cp_async_wait() { + asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); +} + +#endif + +} // namespace MARLIN_NAMESPACE_NAME + +#endif \ No newline at end of file diff --git a/src/infiniop/ops/awq_marlin_gemm/marlin/marlin_dtypes.cuh b/src/infiniop/ops/awq_marlin_gemm/marlin/marlin_dtypes.cuh new file mode 100644 index 000000000..3ee06faef --- /dev/null +++ b/src/infiniop/ops/awq_marlin_gemm/marlin/marlin_dtypes.cuh @@ -0,0 +1,155 @@ + +#ifndef _data_types_cuh +#define _data_types_cuh +#include "../core/scalar_type.hpp" +#include "marlin.cuh" +#include +#include +#include + +#ifndef MARLIN_NAMESPACE_NAME +#define MARLIN_NAMESPACE_NAME marlin +#endif + +namespace MARLIN_NAMESPACE_NAME { + +template +class MarlinScalarType { +}; + +template <> +class MarlinScalarType { +public: + using scalar_t = half; + using scalar_t2 = half2; + using scalar_t4 = half2; + using scalar_32bit_t = half2; + + // Matrix fragments for tensor core instructions; their precise layout is + // documented here: + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type + using FragA = Vec; + using FragB = Vec; + using FragC = Vec; + using FragS = Vec; + using FragS0 = Vec<__nv_fp8x2_e4m3, 1>; + using FragZP = Vec; + + static __device__ float inline num2float(const half x) { + return __half2float(x); + } + + static __device__ half2 inline num2num2(const half x) { + return __half2half2(x); + } + + static __device__ half2 inline nums2num2(const half x1, const half x2) { + return __halves2half2(x1, x2); + } + + static __host__ __device__ half inline float2num(const float x) { + return __float2half(x); + } + + static __host__ __device__ float2 inline 
num22float2(const half2 x) { + return __half22float2(x); + } +}; + +template <> +class MarlinScalarType { +public: + using scalar_t = nv_bfloat16; + using scalar_t2 = nv_bfloat162; + using scalar_t4 = nv_bfloat162; + using scalar_32bit_t = nv_bfloat162; + + using FragA = Vec; + using FragB = Vec; + using FragC = Vec; + using FragS = Vec; + using FragS0 = Vec<__nv_fp8x2_e4m3, 1>; + using FragZP = Vec; + +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 + static __device__ float inline num2float(const nv_bfloat16 x) { + return __bfloat162float(x); + } + + static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) { + return __bfloat162bfloat162(x); + } + + static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1, + const nv_bfloat16 x2) { + return __halves2bfloat162(x1, x2); + } + + static __host__ __device__ nv_bfloat16 inline float2num(const float x) { + return __float2bfloat16(x); + } + + static __host__ __device__ float2 inline num22float2(const nv_bfloat162 x) { + return __bfloat1622float2(x); + } +#endif +}; + +template <> +class MarlinScalarType { +public: + using scalar_t = __nv_fp8_e4m3; + using scalar_t2 = __nv_fp8x2_e4m3; + using scalar_t4 = __nv_fp8x4_e4m3; + using scalar_32bit_t = __nv_fp8x4_e4m3; + + using FragA = Vec<__nv_fp8x4_e4m3, 4>; + using FragB = Vec<__nv_fp8x4_e4m3, 2>; + using FragC = Vec; + using FragZP = Vec<__nv_fp8x2_e4m3, 4>; + + static __host__ __device__ + float2 inline num22float2(const __nv_fp8x2_e4m3 x) { + return (float2)x; + } +}; + +template <> +class MarlinScalarType { +public: + using scalar_t = int8_t; + using scalar_t2 = int16_t; + using scalar_t4 = int32_t; + using scalar_32bit_t = int32_t; + + using FragA = Vec; + using FragB = Vec; + using FragC = Vec; + using FragZP = Vec; +}; + +template +class MarlinScalarType2 { +}; + +template <> +class MarlinScalarType2 : public MarlinScalarType { +}; + +template <> +class MarlinScalarType2 + : public MarlinScalarType { +}; + +template <> +class MarlinScalarType2<__nv_fp8_e4m3> + : public MarlinScalarType { +}; + +template <> +class MarlinScalarType2 : public MarlinScalarType { +}; + +} // namespace MARLIN_NAMESPACE_NAME + +#endif diff --git a/src/infiniop/ops/awq_marlin_gemm/marlin/marlin_mma.h b/src/infiniop/ops/awq_marlin_gemm/marlin/marlin_mma.h new file mode 100644 index 000000000..bc905e71f --- /dev/null +++ b/src/infiniop/ops/awq_marlin_gemm/marlin/marlin_mma.h @@ -0,0 +1,267 @@ + +#include "marlin_dtypes.cuh" + +namespace MARLIN_NAMESPACE_NAME { + +// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 +// output/accumulation. 
+template +__device__ inline void mma( + const typename MarlinScalarType::FragA &a_frag, + const typename MarlinScalarType::FragB &frag_b, + typename MarlinScalarType::FragC &frag_c, int idx = 0) { + const uint32_t *a = reinterpret_cast(&a_frag); + const uint32_t *b = reinterpret_cast(&frag_b); + using scalar_t = typename MarlinScalarType::scalar_t; + if constexpr (!std::is_same::value || k_size != 16) { + static_assert(!use_fp16_accum); + } + + if constexpr (k_size == 16) { + if constexpr (std::is_same::value && !use_fp16_accum) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 + float *c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(b[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]), + "f"(c[3])); + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[2]), "r"(a[3]), "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), + "f"(c[3])); +#else + float *c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); +#endif + } else if constexpr (std::is_same::value && use_fp16_accum) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 + uint32_t *c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " + "{%0,%1}, {%2,%3}, {%4}, {%5,%6};\n" + : "=r"(c[0]), "=r"(c[1]) + : "r"(a[0]), "r"(a[1]), "r"(b[0]), "r"(c[0]), "r"(c[1])); + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " + "{%0,%1}, {%2,%3}, {%4}, {%5,%6};\n" + : "=r"(c[0]), "=r"(c[1]) + : "r"(a[2]), "r"(a[3]), "r"(b[1]), "r"(c[0]), "r"(c[1])); +#else + uint32_t *c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " + "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n" + : "=r"(c[0]), "=r"(c[1]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "r"(c[0]), "r"(c[1])); +#endif + } else if constexpr (std::is_same::value) { + float *c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + float *c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "f"(c[0]), + "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + int32_t *c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(a[idx * 2]), "r"(a[idx * 2 + 1]), "r"(b[idx]), "r"(c[0]), + "r"(c[1]), "r"(c[2]), "r"(c[3])); + } + } else if (k_size == 32) { + if constexpr (std::is_same::value) { + 
float *c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + int32_t *c = reinterpret_cast(&frag_c); +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1}, {%2}, {%3}, {%4,%5};\n" + : "=r"(c[0]), "=r"(c[1]) + : "r"(a[0]), "r"(b[0]), "r"(c[0]), "r"(c[1])); + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1}, {%2}, {%3}, {%4,%5};\n" + : "=r"(c[2]), "=r"(c[3]) + : "r"(a[1]), "r"(b[0]), "r"(c[2]), "r"(c[3])); + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1}, {%2}, {%3}, {%4,%5};\n" + : "=r"(c[0]), "=r"(c[1]) + : "r"(a[2]), "r"(b[1]), "r"(c[0]), "r"(c[1])); + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1}, {%2}, {%3}, {%4,%5};\n" + : "=r"(c[2]), "=r"(c[3]) + : "r"(a[3]), "r"(b[1]), "r"(c[2]), "r"(c[3])); +#else + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); +#endif + } + } +} + +template +__device__ inline void mma_trans( + const typename MarlinScalarType::FragA &a_frag, + const typename MarlinScalarType::FragB &frag_b, + const typename MarlinScalarType::FragB &frag_b2, + typename MarlinScalarType::FragC &frag_c) { + const uint32_t *a = reinterpret_cast(&a_frag); + const uint32_t *b = reinterpret_cast(&frag_b); + const uint32_t *b2 = reinterpret_cast(&frag_b2); + float *c = reinterpret_cast(&frag_c); + using scalar_t = typename MarlinScalarType::scalar_t; + if constexpr (!std::is_same::value || k_size != 16) { + static_assert(!use_fp16_accum); + } + + if constexpr (k_size == 16) { + if constexpr (std::is_same::value && !use_fp16_accum) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 + float *c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]), + "f"(c[3])); + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[1]), "r"(b2[1]), "r"(a[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), + "f"(c[3])); +#else + float *c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); +#endif + } else if constexpr (std::is_same::value && use_fp16_accum) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 + uint32_t *c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " + "{%0,%1}, {%2,%3}, {%4}, {%5,%6};\n" + : "=r"(c[0]), "=r"(c[1]) + : "r"(b[0]), "r"(b2[0]), 
"r"(a[0]), "r"(c[0]), "r"(c[1])); + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " + "{%0,%1}, {%2,%3}, {%4}, {%5,%6};\n" + : "=r"(c[0]), "=r"(c[1]) + : "r"(b[1]), "r"(b2[1]), "r"(a[1]), "r"(c[0]), "r"(c[1])); +#else + uint32_t *c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " + "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n" + : "=r"(c[0]), "=r"(c[1]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "r"(c[0]), "r"(c[1])); +#endif + } else if constexpr (std::is_same::value) { + float *c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + float *c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.e4m3.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]), + "f"(c[3])); + } else if constexpr (std::is_same::value) { + int32_t *c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(a[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]), + "r"(c[3])); + } + } else { + if constexpr (std::is_same::value) { + float *c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + int32_t *c = reinterpret_cast(&frag_c); +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1}, {%2}, {%3}, {%4,%5};\n" + : "=r"(c[0]), "=r"(c[1]) + : "r"(b[0]), "r"(a[0]), "r"(c[0]), "r"(c[1])); + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1}, {%2}, {%3}, {%4,%5};\n" + : "=r"(c[2]), "=r"(c[3]) + : "r"(b2[1]), "r"(a[0]), "r"(c[2]), "r"(c[3])); + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1}, {%2}, {%3}, {%4,%5};\n" + : "=r"(c[0]), "=r"(c[1]) + : "r"(b[0]), "r"(a[1]), "r"(c[0]), "r"(c[1])); + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1}, {%2}, {%3}, {%4,%5};\n" + : "=r"(c[2]), "=r"(c[3]) + : "r"(b2[1]), "r"(a[1]), "r"(c[2]), "r"(c[3])); +#else + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); +#endif + } + } +} + +} // namespace MARLIN_NAMESPACE_NAME \ No newline at end of file diff --git a/src/infiniop/ops/awq_marlin_gemm/marlin/marlin_template.h b/src/infiniop/ops/awq_marlin_gemm/marlin/marlin_template.h new file mode 100644 index 000000000..d705af8fb --- /dev/null +++ 
b/src/infiniop/ops/awq_marlin_gemm/marlin/marlin_template.h
@@ -0,0 +1,1988 @@
+/*
+ * Modified by Neural Magic
+ * Copyright (C) Marlin.2024 Elias Frantar
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * Adapted from https://github.com/IST-DASLab/marlin
+ */
+
+#ifndef MARLIN_NAMESPACE_NAME
+#define MARLIN_NAMESPACE_NAME marlin
+#endif
+
+#include "../core/scalar_type.hpp"
+#include "dequant.h"
+#include "marlin.cuh"
+#include "marlin_dtypes.cuh"
+#include "marlin_mma.h"
+
+#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t)                     \
+    static_assert(std::is_same<scalar_t, half>::value                 \
+                      || std::is_same<scalar_t, nv_bfloat16>::value,  \
+                  "only float16 and bfloat16 are supported");
+
+namespace MARLIN_NAMESPACE_NAME {
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
+
+template <typename scalar_t,          // compute dtype, half or nv_bfloat16
+          const int num_bits,         // number of bits used for weights
+          const int threads,          // number of threads in a threadblock
+          const int thread_m_blocks,  // number of 16x16 blocks in the m
+                                      // dimension (batchsize) of the
+                                      // threadblock
+          const int thread_n_blocks,  // same for n dimension (output)
+          const int thread_k_blocks,  // same for k dimension (reduction)
+          const int stages,           // number of stages for the async
+                                      // global->shared fetch pipeline
+          const bool has_act_order,   // whether act_order is enabled
+          const int group_blocks,     // number of consecutive 16x16 blocks
+                                      // with a separate quantization scale
+          const bool is_zp_float      // is zero point of float16 type?
+          >
+__global__ void Marlin(
+    const int4 *__restrict__ A, // fp16 input matrix of shape mxk
+    const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn
+    int4 *__restrict__ C,       // fp16 output buffer of shape mxn
+    int4 *__restrict__ C_tmp,   // fp32 tmp output buffer (for reduce)
+    const int4 *__restrict__ scales_ptr, // fp16 quantization scales of shape
+                                         // (k/groupsize)xn
+    const int *__restrict__ g_idx, // int32 group indices of shape k
+    int num_groups, // number of scale groups per output channel
+    int prob_m,     // batch dimension m
+    int prob_n,     // output dimension n
+    int prob_k,     // reduction dimension k
+    int *locks,     // extra global storage for barrier synchronization
+    bool use_fp32_reduce // whether to use fp32 global reduce
+) {
+}
+
+} // namespace marlin
+
+#else
+
+// Instruction for loading a full 16x16 matrix fragment of operand A from
+// shared memory, directly in tensor core layout.
+template <int a_type_id, int count = 4>
+__device__ inline void ldsm(typename MarlinScalarType<a_type_id>::FragA &frag_a,
+                            const void *smem_ptr) {
+  uint32_t *a = reinterpret_cast<uint32_t *>(&frag_a);
+  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
+  if constexpr (count == 4) {
+    asm volatile(
+        "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
+        : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
+        : "r"(smem));
+  } else if constexpr (count == 2) {
+    asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n"
+                 : "=r"(a[0]), "=r"(a[1])
+                 : "r"(smem));
+  } else if constexpr (count == 1) {
+    asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n"
+                 : "=r"(a[0])
+                 : "r"(smem));
+  } else {
+    static_assert(count == 1 || count == 2 || count == 4, "invalid count");
+  }
+}
+
+// Multiply dequantized values by the corresponding quantization scale; used
+// only for grouped quantization.
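+// Illustrative walk-through (assumed values, not taken from a real trace):
+// FragB holds four dequantized weights as two half2 registers,
+//   frag_b = {w0, w1} and {w2, w3}.
+// With a per-group scale s_g, num2num2 broadcasts it to {s_g, s_g}, and the
+// two __hmul2 calls below yield {w0*s_g, w1*s_g} and {w2*s_g, w3*s_g}, i.e.
+// one packed multiply per register instead of four scalar multiplies.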
+template <int a_type_id>
+__device__ inline void scale(typename MarlinScalarType<a_type_id>::FragB &frag_b,
+                             typename MarlinScalarType<a_type_id>::FragS &frag_s,
+                             int i) {
+  using scalar_t = typename MarlinScalarType<a_type_id>::scalar_t;
+  using scalar_t2 = typename MarlinScalarType<a_type_id>::scalar_t2;
+  scalar_t2 s = MarlinScalarType<a_type_id>::num2num2(
+      reinterpret_cast<scalar_t *>(&frag_s)[i]);
+  frag_b[0] = __hmul2(frag_b[0], s);
+  frag_b[1] = __hmul2(frag_b[1], s);
+}
+
+template <int a_type_id>
+__device__ inline void scale_and_sub(
+    typename MarlinScalarType<a_type_id>::FragB &frag_b,
+    typename MarlinScalarType<a_type_id>::scalar_t s,
+    typename MarlinScalarType<a_type_id>::scalar_t zp) {
+  using scalar_t = typename MarlinScalarType<a_type_id>::scalar_t;
+  using scalar_t2 = typename MarlinScalarType<a_type_id>::scalar_t2;
+  scalar_t2 s2 = MarlinScalarType<a_type_id>::num2num2(s);
+  scalar_t2 zp2 = MarlinScalarType<a_type_id>::num2num2(zp);
+  frag_b[0] = __hfma2(frag_b[0], s2, __hneg2(zp2));
+  frag_b[1] = __hfma2(frag_b[1], s2, __hneg2(zp2));
+}
+
+template <int a_type_id>
+__device__ inline void sub_zp(
+    typename MarlinScalarType<a_type_id>::FragB &frag_b,
+    typename MarlinScalarType<a_type_id>::scalar_t2 &frag_zp, int i) {
+  using scalar_t = typename MarlinScalarType<a_type_id>::scalar_t;
+  using scalar_t2 = typename MarlinScalarType<a_type_id>::scalar_t2;
+  scalar_t2 zp = MarlinScalarType<a_type_id>::num2num2(
+      reinterpret_cast<scalar_t *>(&frag_zp)[i]);
+  frag_b[0] = __hsub2(frag_b[0], zp);
+  frag_b[1] = __hsub2(frag_b[1], zp);
+}
+
+// Same as above, but for act_order (each K is multiplied individually)
+template <int a_type_id>
+__device__ inline void scale4(
+    typename MarlinScalarType<a_type_id>::FragB &frag_b,
+    typename MarlinScalarType<a_type_id>::FragS &frag_s_1,
+    typename MarlinScalarType<a_type_id>::FragS &frag_s_2,
+    typename MarlinScalarType<a_type_id>::FragS &frag_s_3,
+    typename MarlinScalarType<a_type_id>::FragS &frag_s_4, int i) {
+  using scalar_t = typename MarlinScalarType<a_type_id>::scalar_t;
+  using scalar_t2 = typename MarlinScalarType<a_type_id>::scalar_t2;
+
+  scalar_t2 s_val_1_2;
+  s_val_1_2.x = reinterpret_cast<scalar_t *>(&frag_s_1)[i];
+  s_val_1_2.y = reinterpret_cast<scalar_t *>(&frag_s_2)[i];
+
+  scalar_t2 s_val_3_4;
+  s_val_3_4.x = reinterpret_cast<scalar_t *>(&frag_s_3)[i];
+  s_val_3_4.y = reinterpret_cast<scalar_t *>(&frag_s_4)[i];
+
+  frag_b[0] = __hmul2(frag_b[0], s_val_1_2);
+  frag_b[1] = __hmul2(frag_b[1], s_val_3_4);
+}
+
+// Given 2 floats, multiply each by one of 2 scales (halves)
+template <int a_type_id>
+__device__ inline void scale_float(
+    float *c, typename MarlinScalarType<a_type_id>::FragS &s) {
+  using scalar_t = typename MarlinScalarType<a_type_id>::scalar_t;
+  scalar_t *s_ptr = reinterpret_cast<scalar_t *>(&s);
+  c[0] = __fmul_rn(c[0], MarlinScalarType<a_type_id>::num2float(s_ptr[0]));
+  c[1] = __fmul_rn(c[1], MarlinScalarType<a_type_id>::num2float(s_ptr[1]));
+}
+
+// Wait until barrier reaches `count`, then lock for current threadblock.
+__device__ inline void barrier_acquire(int *lock, int count) {
+  if (threadIdx.x == 0) {
+    int state = -1;
+    do {
+      // Guarantee that subsequent writes by this threadblock will be visible
+      // globally.
+      asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
+                   : "=r"(state)
+                   : "l"(lock));
+    } while (state != count);
+  }
+  __syncthreads();
+}
+
+// Release barrier and increment visitation count.
+__device__ inline void barrier_release(int *lock, bool reset = false) {
+  __syncthreads();
+  if (threadIdx.x == 0) {
+    if (reset) {
+      lock[0] = 0;
+      return;
+    }
+    int val = 1;
+    // Make sure that all writes since acquiring this barrier are visible
+    // globally, while releasing the barrier.
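+    // Release ordering: the fence.acq_rel below makes this block's prior
+    // global writes visible before the (relaxed) increment of the lock, which
+    // pairs with the ld.global.acquire polling loop in barrier_acquire above.
+    // A rough CUDA-level sketch of the same pattern (not the code used here):
+    //   __threadfence();     // order prior global writes
+    //   atomicAdd(lock, 1);  // publish the increment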
+ asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" + : + : "l"(lock), "r"(val)); + } +} + +// Wait until value of lock to be negative, and then add 1 +__device__ inline void wait_negative_and_add(int *lock) { + if (threadIdx.x == 0) { + int state = 0; + do { + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" + : "=r"(state) + : "l"(lock)); + } while (state >= 0); + atomicAdd(lock, 1); + } + __syncthreads(); +} + +template shared + // fetch pipeline + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? + > +__global__ void Marlin( + const int4 *__restrict__ A0, // fp16 input matrix of shape mxk + const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn + int4 *__restrict__ C0, // fp16 output buffer of shape mxn + int4 *__restrict__ C_tmp, // fp32 tmp output buffer (for reduce) + const int4 *__restrict__ b_bias_ptr, + // float scales of input matrix, only used when is_a_8bit == true. + // shape (m,) + const float *__restrict__ a_scales_ptr, + // fp16 quantization scales. shape (k/groupsize, n) + const int4 *__restrict__ scales_ptr, + // float global scale (for nvfp4// only) + const float *__restrict__ global_scale_ptr, + // 4bit packed zero-points of shape + // (k/groupsize, n/pack_factor) + const int4 *__restrict__ zp_ptr, + // int32 group indices of shape k + const int *__restrict__ g_idx, + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int lda, // A.stride(0), equal to prob_k is A is contiguous + int *locks, // extra global storage for barrier synchronization + bool has_bias, + bool use_atomic_add, // whether to use atomic add to reduce + bool use_fp32_reduce, // whether to use fp32 global reduce + int max_shared_mem) { + // Each threadblock processes one "stripe" of the B matrix with (roughly) the + // same size, which might involve multiple column "slices" (of width 16 * + // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM + // example: + // 0 1 3 + // 0 2 3 + // 1 2 4 + // While this kind of partitioning makes things somewhat more complicated, it + // ensures good utilization of all SMs for many kinds of shape and GPU + // configurations, while requiring as few slow global cross-threadblock + // reductions as possible. + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 890 + // FP8 computation is only supported for Ada Lovelace or newer architectures. 
+  if constexpr (a_type_id == vllm::kFE4M3fn.id()) {
+    return;
+  }
+#endif
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
+  // Turing TensorCores only support fp16 and int8 operands
+  if constexpr (a_type_id != vllm::kFloat16.id() && a_type_id != vllm::kS8.id()) {
+    return;
+  }
+#endif
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
+  constexpr auto num_bits = vllm::ScalarType::from_id(b_type_id).size_bits();
+  // Disable use_fp16_accum for NVFP4 and for cases when group_size == -1 &&
+  // num_bits == 4
+  constexpr bool use_fp16_accum =
+      a_type_id == vllm::kFloat16.id() &&
+      (!(b_type_id == vllm::kFE2M1f.id() && s_type_id == vllm::kFE4M3fn.id()) &&
+       !(group_blocks == -1 && num_bits == 4));
+#else
+  constexpr bool use_fp16_accum = false;
+#endif
+  using Adtype = MarlinScalarType<a_type_id>;
+  using Cdtype = MarlinScalarType<c_type_id>;
+  const int4 *A = A0;
+  int4 *C = C0;
+
+  using scalar_t = typename MarlinScalarType<a_type_id>::scalar_t;
+  using scalar_t2 = typename MarlinScalarType<a_type_id>::scalar_t2;
+  using scalar_32bit_t = typename MarlinScalarType<a_type_id>::scalar_32bit_t;
+
+  using c_scalar_t = typename MarlinScalarType<c_type_id>::scalar_t;
+  using c_scalar_t2 = typename MarlinScalarType<c_type_id>::scalar_t2;
+
+  using FragA = typename MarlinScalarType<a_type_id>::FragA;
+  using FragB = typename MarlinScalarType<a_type_id>::FragB;
+  using FragC = typename MarlinScalarType<a_type_id>::FragC;
+  using FragS = typename MarlinScalarType<a_type_id>::FragS;
+  using FragZP = typename MarlinScalarType<a_type_id>::FragZP;
+
+  static constexpr auto a_type = vllm::ScalarType::from_id(a_type_id);
+  static constexpr auto b_type = vllm::ScalarType::from_id(b_type_id);
+  static constexpr auto c_type = vllm::ScalarType::from_id(c_type_id);
+  static constexpr auto s_type = vllm::ScalarType::from_id(s_type_id);
+  if constexpr (b_type == vllm::kFE2M1f) {
+    static_assert(s_type == vllm::kFE4M3fn && group_blocks == 1 ||
+                  s_type == vllm::kFE8M0fnu && group_blocks == 2);
+  } else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
+    static_assert(s_type == vllm::kBFloat16);
+  } else if constexpr (std::is_same<scalar_t, half>::value) {
+    static_assert(s_type == vllm::kFloat16);
+  }
+
+  constexpr bool is_a_8bit = a_type.size_bits() == 8;
+  if constexpr (!is_a_8bit) {
+    static_assert(std::is_same<scalar_t, c_scalar_t>::value);
+  }
+  constexpr bool has_zp = b_type == vllm::kU4 || b_type == vllm::kU8;
+  constexpr bool is_int_type = b_type == vllm::kU4 || b_type == vllm::kU8 ||
+                               b_type == vllm::kS4 || b_type == vllm::kS8 ||
+                               b_type == vllm::kU4B8 || b_type == vllm::kU8B128;
+  // see comments of dequant.h for more details
+  constexpr bool dequant_skip_flop =
+      is_a_8bit || b_type == vllm::kFE4M3fn ||
+      b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn ||
+      has_zp && !is_zp_float && !std::is_same<scalar_t, nv_bfloat16>::value ||
+      has_zp && !is_zp_float && !(b_type == vllm::kU8);
+
+  float global_scale_f32 = 1.0f;
+
+  if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) {
+    global_scale_f32 = global_scale_ptr[0];
+  }
+
+  constexpr bool has_act_order = group_blocks == 0;
+  constexpr int m_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks);
+
+  extern __shared__ int4 sh[];
+  float *sh_a_s = reinterpret_cast<float *>(sh);
+  int4 *sh_new = sh + (is_a_8bit ?
(4 * thread_m_blocks) : 0); + constexpr int pack_factor = 32 / b_type.size_bits(); + static_assert(thread_m_blocks == 1 || !m_block_size_8); + + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a + // better partitioning with less reductions + int parallel = 1; + if (prob_m > m_block_size) { + parallel = prob_m / m_block_size; + prob_m = m_block_size; + } + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + + int global_mn_tiles = parallel * n_tiles; + int part2_mn_tiles = global_mn_tiles; + int part1_mn_iters = 0; + bool in_part2 = false; + + if (global_mn_tiles > gridDim.x) { + part2_mn_tiles = global_mn_tiles % gridDim.x; + if (part2_mn_tiles * 3 <= gridDim.x) { + part2_mn_tiles += gridDim.x; + } + part1_mn_iters = (global_mn_tiles - part2_mn_tiles) / gridDim.x; + } + + int iters = div_ceil(k_tiles * part2_mn_tiles, gridDim.x); + + if constexpr (!has_act_order && group_blocks != -1) { + if (group_blocks >= thread_k_blocks) { + // Ensure that the number of tiles in each stripe is a multiple of the + // groupsize; this avoids an annoying special case where a stripe starts + // in the middle of group. + iters = (group_blocks / thread_k_blocks) * div_ceil(iters, (group_blocks / thread_k_blocks)); + } + } + + int slice_row = 0; + int slice_col_par = blockIdx.x; + int slice_col; + int slice_iters = k_tiles; // number of threadblock tiles in the current slice + // total number of active threadblocks in the current slice + int slice_count = 1; + // index of threadblock in current slice; numbered bottom to top + int slice_idx = 0; + + int par_id = 0; + int locks_off = 0; + + if (part2_mn_tiles >= gridDim.x) { + // when part2_mn_tiles >= sms + // then there are at most $sms$ conflict tile blocks + locks_off = blockIdx.x; + } else { + locks_off = (iters * blockIdx.x) / k_tiles - 1; + } + + // Compute all information about the current slice which is required for + // synchronization. 
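+  // Worked example with illustrative numbers: for k_tiles = 3,
+  // part2_mn_tiles = 3 and gridDim.x = 5 there are 9 tile-iterations shared
+  // by 5 threadblocks, so iters = div_ceil(9, 5) = 2. Block 0 then owns
+  // iterations [0, 2) (rows 0..1 of column 0), block 1 owns [2, 4) (row 2 of
+  // column 0 plus row 0 of column 1), and so on -- the staircase pattern
+  // sketched in the comment at the top of this kernel.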
+ bool first_init = true; + auto init_part2_slice = [&]() { + slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= part2_mn_tiles) { + slice_iters = 0; + } + if (slice_iters == 0) { + return; + } + if (slice_row + slice_iters > k_tiles) { + slice_iters = k_tiles - slice_row; + } + slice_count = 1; + slice_idx = 0; + int col_first = iters * div_ceil(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = div_ceil(k_tiles - col_off, iters); + if (col_off > 0) { + slice_count++; + } + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) { + slice_idx = slice_count - 1; + } else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) { + slice_idx--; + } + } + } + if (part2_mn_tiles >= gridDim.x) { + if (slice_count > 1 && slice_idx == slice_count - 1) { + locks_off++; + } + } else { + locks_off++; + } + + if (first_init && use_atomic_add && slice_count > 1 && slice_idx == 0) { + constexpr int threads_per_m = 16 * thread_n_blocks / 8; + int m_per_thread = div_ceil(thread_m_blocks * 16, threads / threads_per_m); + if (m_block_size_8) { + m_per_thread = div_ceil(8, threads / threads_per_m); + } + for (int i = 0; i < m_per_thread; i++) { + int row = threads / threads_per_m * i + threadIdx.x / threads_per_m; + if (row < prob_m) { + int col = slice_col * 16 * thread_n_blocks / 8 + threadIdx.x % threads_per_m; + C[row * prob_n / 8 + col] = {0, 0, 0, 0}; + } + } + // After write zero to output, write a negative value to lock. + // Every SM that processes the same slice would wait for + // the negative value, and then atomicAdd 1 to it. + // After all SMs are processed, the lock value would back to 0 again. + __syncthreads(); + if (threadIdx.x == 0) { + locks[locks_off] = 1 - slice_count; + } + } + + if (slice_col == n_tiles) { + A += 16 * thread_m_blocks * lda / (is_a_8bit ? 16 : 8); + C += 16 * thread_m_blocks * prob_n / 8; + slice_col = 0; + par_id++; + } + if (is_a_8bit && (first_init || slice_col == 0)) { + __syncthreads(); + int a_s_gl_rd = par_id * 16 * thread_m_blocks + threadIdx.x; + cp_async1_ca_pred(&sh_a_s[threadIdx.x], &a_scales_ptr[a_s_gl_rd], + threadIdx.x < prob_m); + } + }; + + auto init_part1_slice = [&]() { + if (part1_mn_iters) { + part1_mn_iters--; + par_id = slice_col_par / n_tiles; + slice_col = slice_col_par % n_tiles; + slice_iters = k_tiles; + A = A0 + 16 * thread_m_blocks / (is_a_8bit ? 16 : 8) * par_id * lda; + C = C0 + 16 * thread_m_blocks / 8 * par_id * prob_n; + if (is_a_8bit) { + __syncthreads(); + int a_s_gl_rd = par_id * 16 * thread_m_blocks + threadIdx.x; + cp_async1_ca_pred(&sh_a_s[threadIdx.x], &a_scales_ptr[a_s_gl_rd], + threadIdx.x < prob_m); + } + } + }; + + auto init_slice = [&]() { + if (!in_part2 && !part1_mn_iters) { + in_part2 = true; + slice_col_par = (iters * blockIdx.x) / k_tiles; + slice_row = (iters * blockIdx.x) % k_tiles; + slice_col = (slice_col_par + global_mn_tiles - part2_mn_tiles) % n_tiles; + par_id = (slice_col_par + global_mn_tiles - part2_mn_tiles) / n_tiles; + A = A0 + 16 * thread_m_blocks / (is_a_8bit ? 
16 : 8) * par_id * lda; + C = C0 + 16 * thread_m_blocks / 8 * par_id * prob_n; + } + if (!in_part2) { + init_part1_slice(); + } else { + init_part2_slice(); + first_init = false; + } + }; + + init_slice(); + + // A sizes/strides + + // stride of the A matrix in global memory + int a_gl_stride = lda / (is_a_8bit ? 16 : 8); + // stride of an A matrix tile in shared memory + constexpr int a_sh_stride = 16 * thread_k_blocks / (is_a_8bit ? 16 : 8); + // delta between subsequent A tiles in global memory + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / (is_a_8bit ? 16 : 8); + // between subsequent accesses within a tile + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); + // between shared memory writes + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); + // within a shared memory tile + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; + // overall size of a tile + constexpr int a_sh_stage = a_sh_stride * m_block_size; + // number of shared write iterations for a tile + constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta); + + // B sizes/strides + int b_gl_stride = 16 * prob_n / (pack_factor * (is_a_8bit ? 2 : 4)); + constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / (is_a_8bit ? 2 : 4); + constexpr int b_thread_vecs = b_type.size_bits() == 4 ? 1 : 2; + constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; + + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks / (is_a_8bit ? 2 : 1); + constexpr int b_sh_wr_delta = threads * b_thread_vecs; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks / (is_a_8bit ? 2 : 1); + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + // Scale sizes/strides without act_order + int s_gl_stride = prob_n / (b_type == vllm::kFE2M1f ? 16 : 8); + constexpr int s_sh_stride = 16 * thread_n_blocks / (b_type == vllm::kFE2M1f ? 16 : 8); + constexpr int s_tb_groups = !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks + ? thread_k_blocks / group_blocks + : 1; + constexpr int s_sh_stage = s_tb_groups * s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + + // Scale size/strides with act_order + constexpr int tb_k = 16 * thread_k_blocks; + constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; + // constexpr int act_s_row_stride = 1; + // int act_s_col_stride = act_s_row_stride * num_groups; + constexpr int act_s_max_num_groups = 32; + int act_s_col_stride = 1; + int act_s_col_warp_stride = act_s_col_stride * 8; + + constexpr int tb_n_warps = thread_n_blocks / (is_a_8bit ? 2 : 4); + int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; + + // Zero-points sizes/strides + int zp_gl_stride = is_zp_float ? prob_n / 8 : (prob_n / pack_factor) / 4; + constexpr int zp_sh_stride = is_zp_float + ? 16 * thread_n_blocks / 8 + : ((16 * thread_n_blocks) / pack_factor) / 4; + constexpr int zp_tb_groups = s_tb_groups; + constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0; + int zp_gl_rd_delta = zp_gl_stride; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. + int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % (16 / (m_block_size_8 ? 2 : 1))) + (threadIdx.x % 32) / (16 / (m_block_size_8 ? 
2 : 1)); + a_sh_rd += 2 * ((threadIdx.x / 32) / tb_n_warps) * b_sh_wr_iters; + + int b_gl_rd; + if (threads <= b_sh_stride) { + b_gl_rd = threadIdx.x; + } else { + b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); + } + + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + auto b_sh_rd = threadIdx.x * b_thread_vecs; + b_sh_rd += b_sh_rd / b_sh_stride * (b_sh_stride * (b_sh_wr_iters - 1)); + + // For act_order + int slice_k_start = tb_k * slice_row; + int slice_k_finish = slice_k_start + tb_k * slice_iters; + int slice_k_start_shared_fetch = slice_k_start; + int slice_n_offset = act_s_col_tb_stride * slice_col; + + // No act_order + int s_gl_rd; + if constexpr (!has_act_order) { + if constexpr (group_blocks == -1) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } else if constexpr (group_blocks >= thread_k_blocks) { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks + threadIdx.x / s_sh_stride) + s_sh_stride * slice_col + threadIdx.x % s_sh_stride; + } + } + auto s_sh_wr = threadIdx.x; + bool s_sh_wr_pred = threadIdx.x < s_sh_stage; + + // Zero-points + int zp_gl_rd; + if constexpr (has_zp) { + if constexpr (group_blocks == -1) { + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } else if constexpr (group_blocks >= thread_k_blocks) { + zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + zp_sh_stride * slice_col + threadIdx.x; + } else { + zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks + threadIdx.x / zp_sh_stride) + zp_sh_stride * slice_col + threadIdx.x % zp_sh_stride; + } + } + auto zp_sh_wr = threadIdx.x; + bool zp_sh_wr_pred = zp_sh_stage > 0 && threadIdx.x < zp_sh_stage; + + // We use a different scale layout for grouped and column-wise quantization as + // we scale a `half2` tile in column-major layout in the former and in + // row-major in the latter case. + int s_sh_rd; + if constexpr (is_a_8bit) { + s_sh_rd = 4 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 4); + } else if constexpr (group_blocks != -1) { + s_sh_rd = 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) / 4; + } else if constexpr (group_blocks == -1 && (m_block_size_8 || (has_zp && !dequant_skip_flop))) { + s_sh_rd = 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) / 8; + } else { + s_sh_rd = 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) % 4; + } + + int bias_sh_rd; + if constexpr (m_block_size_8) { + bias_sh_rd = 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) / 8; + } else { + bias_sh_rd = (is_a_8bit ? 
4 : 8) * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) % 4; + } + + int bias_sh_wr = threadIdx.x; + int bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x; + + // Zero-points have the same read layout as the scales + // (without column-wise case) + constexpr int num_col_threads = 8; + constexpr int num_row_threads = 4; + constexpr int num_ints_per_thread = 8 / pack_factor; + int zp_sh_rd; + if constexpr (has_zp) { + if constexpr (is_zp_float) { + if constexpr (group_blocks != -1) { + zp_sh_rd = 8 * ((threadIdx.x / 32) % tb_n_warps) + (threadIdx.x % 32) / 4; + } + } else if (is_a_8bit) { + zp_sh_rd = num_ints_per_thread * num_col_threads * ((threadIdx.x / 32) % tb_n_warps / 2) + num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); + } else { + zp_sh_rd = num_ints_per_thread * num_col_threads * ((threadIdx.x / 32) % tb_n_warps) + num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); + } + } + + // Precompute which thread should not read memory in which iterations; this is + // needed if there are more threads than required for a certain tilesize or + // when the batchsize is not a multiple of 16. + bool a_sh_wr_pred[a_sh_wr_iters]; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + } + + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? + auto transform_a = [&](int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ (row % 8); + }; + // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + } + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { +#pragma unroll + for (int j = 0; j < thread_m_blocks; j++) { + a_sh_rd_trans[i][j] = transform_a(2 * i + a_sh_rd_delta_i * j + a_sh_rd); + } + } + + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependencies between subsequent accesses with a tile by + // maintining multiple pointers (we have enough registers), a tiny + // optimization. + + // Shared memory storage for global fetch pipelines. + constexpr int sh_red_size = (2 * thread_n_blocks + 1) * 16 * thread_m_blocks; + constexpr int sh_b_size = stages * b_sh_stage; + int4 *sh_b = sh_new; + int4 *sh_red = sh_new; + constexpr int sh_size_b_red_min = (sh_red_size < sh_b_size ? sh_red_size : sh_b_size); + constexpr int sh_size_b_red_max = (sh_red_size > sh_b_size ? sh_red_size : sh_b_size); + constexpr int sh_bias_size = (thread_n_blocks * 16 / 8); + constexpr int sh_b_red_bias_size = sh_size_b_red_max > (sh_size_b_red_min + sh_bias_size) + ? 
sh_size_b_red_max + : (sh_size_b_red_min + sh_bias_size); + + int4 *sh_bias = sh_new + sh_size_b_red_min; + int4 *sh_g_idx = sh_new + sh_b_red_bias_size; + int4 *sh_zp = sh_g_idx + (stages * g_idx_stage); + constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride) + : (stages * s_sh_stage); + int4 *sh_s = sh_zp + (stages * zp_sh_stage); + int4 *sh_a = sh_s + sh_s_size; + + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2][b_thread_vecs]; + FragC frag_c[thread_m_blocks][is_a_8bit ? 2 : 4][2]; + FragC frag_c_tmp[thread_m_blocks][is_a_8bit ? 2 : 4][2]; + FragS frag_s[2][4]; // No act-order + FragS frag_bias[2][4]; + FragS act_frag_s[2][4][4]; // For act-order + int frag_qzp[2][num_ints_per_thread]; // Zero-points + FragZP frag_zp; // Zero-points in fp16 + FragZP frag_zpf[2]; // Zero-points in fp16 in HQQ + + if constexpr (is_a_8bit) { +#pragma unroll + for (int j = 0; j < 2; j++) { +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { +#pragma unroll + for (int g = 0; g < 4; g++) { + frag_c_tmp[i][j][0][g] = 0.0f; + } + +#pragma unroll + for (int g = 0; g < 4; g++) { + frag_c_tmp[i][j][1][g] = 0.0f; + } + } + } + } + + // Zero accumulators. + auto zero_accums = [&]() { +#pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + int sh_first_group_id = -1; + int sh_num_groups = -1; + + auto fetch_act_order_scales_to_shared = [&](bool is_async, int first_group_id, + int last_group_id) { + sh_first_group_id = first_group_id; + sh_num_groups = last_group_id - first_group_id + 1; + + if (sh_num_groups > act_s_max_num_groups) { + sh_num_groups = act_s_max_num_groups; + } + + if (sh_first_group_id + sh_num_groups > num_groups) { + sh_num_groups = num_groups - sh_first_group_id; + } + + int row_offset = first_group_id * s_gl_stride; + + if (is_async) { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], + &scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + threadIdx.x]); + } + } + } else { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + sh_s[(i * s_sh_stride) + threadIdx.x] = scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + threadIdx.x]; + } + } + } + }; + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. 
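+  // Rough shape of the resulting software pipeline (a sketch only; the real
+  // control flow lives in start_pipes() and the main loop further down):
+  //   for (int i = 0; i < stages - 1; i++)            // prologue
+  //     fetch_to_shared(i, i, i < slice_iters);
+  //   while (slice_iters) {
+  //     compute on stage `pipe` from registers;       // tensor-core math
+  //     fetch_to_shared((pipe + stages - 1) % stages, // overlap next fetch
+  //                     pipe, slice_iters >= stages);
+  //     wait_for_stage();                             // cp.async wait
+  //   }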
+ auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + if (pred) { + int4 *sh_a_stage = sh_a + a_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + cp_async4_pred( + &sh_a_stage[a_sh_wr_trans[i]], + &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], + a_sh_wr_pred[i]); + } + int4 *sh_b_stage = sh_b + b_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < (b_sh_wr_iters * b_thread_vecs); i++) { + constexpr int count = div_ceil(b_sh_stride, threads); + int b_gl_idx = b_gl_rd + (i % count) * threads + b_gl_stride * (i / count) * div_ceil(threads, b_sh_stride); + + cp_async4(&sh_b_stage[threads * i + threadIdx.x], &B[b_gl_idx]); + } + + b_gl_rd += b_gl_rd_delta_o; + + if constexpr (has_act_order) { + // Fetch g_idx thread-block portion + int full_pipe = a_off; + int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; + if (cur_k < prob_k && cur_k < slice_k_finish) { + int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + + int4 const *cur_g_idx_stage_ptr = reinterpret_cast(&g_idx[cur_k]); + + if (threadIdx.x < g_idx_stage) { + cp_async4_pred(&sh_g_idx_stage[threadIdx.x], + &cur_g_idx_stage_ptr[threadIdx.x]); + } + } + } else { + if constexpr (group_blocks != -1) { + int4 *sh_s_stage = sh_s + s_sh_stage * pipe; + + // Only fetch scales if this tile starts a new group + if (pipe % div_ceil(group_blocks, thread_k_blocks) == 0) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta * s_tb_groups; + } + } + + if constexpr (has_zp && group_blocks != -1) { + int4 *sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + // Only fetch zero points if this tile starts a new group + if (pipe % div_ceil(group_blocks, thread_k_blocks) == 0) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta * zp_tb_groups; + } + } + } + } + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. + cp_async_fence(); + }; + + auto fetch_col_zp_to_shared = [&]() { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + }; + + auto fetch_col_scale_to_shared = [&]() { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. 
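+  // Register-level double buffering: frag_a and frag_b_quant are indexed by
+  // k % 2, so the fragments for sub-tile k + 1 are loaded from shared memory
+  // while the mma instructions for sub-tile k are still in flight. E.g.
+  // fetch_to_registers(1, pipe) fills frag_a[1][*] while the matmul for k = 0
+  // consumes frag_a[0][*].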
+ auto fetch_to_registers = [&](int k, int pipe) { + int4 *sh_a_stage = sh_a + a_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + ldsm( + frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4 *sh_b_stage = sh_b + b_sh_stage * pipe; + +#pragma unroll + for (int i = 0; i < b_thread_vecs; i++) { + frag_b_quant[k % 2][i] = *reinterpret_cast( + &sh_b_stage[b_sh_stride * (k % b_sh_wr_iters) + b_sh_rd + i]); + } + }; + + bool is_same_group[stages]; + int same_group_id[stages]; + + auto init_same_group = [&](int pipe) { + if constexpr (!has_act_order) { + return; + } + + int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int *sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + int group_id_1 = sh_g_idx_int_ptr[0]; + int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; + + is_same_group[pipe] = group_id_1 == group_id_2; + same_group_id[pipe] = group_id_1; + }; + + auto fetch_scales_to_registers = [&](int k, int full_pipe) { + int pipe = full_pipe % stages; + using IT1 = typename std::conditional_t; + using IT0 = typename std::conditional_t; + constexpr int group_blocks2 = div_ceil(group_blocks, is_a_8bit ? 2 : 1); + + if constexpr (!has_act_order) { + // No act-order case + if constexpr (group_blocks == -1) { + // load only when starting a new slice + if (k == 0 && full_pipe == 0 && dequant_skip_flop) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } else if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= thread_k_blocks) { + constexpr int g = group_blocks / thread_k_blocks; + if (pipe % g == 0) { + if (k % b_sh_wr_iters == 0) { + int4 *sh_s_stage = sh_s + s_sh_stage * (g * (pipe / g)); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } else { + reinterpret_cast(&frag_s[1])[0] = reinterpret_cast(&frag_s[0])[0]; + } + } + } else if (group_blocks2 < b_sh_wr_iters || k % b_sh_wr_iters == 0) { + auto warp_id = threadIdx.x / 32; + int warp_row = warp_id / tb_n_warps; + + int k_blocks = b_sh_wr_iters * warp_row + k % b_sh_wr_iters; + int cur_group_id = k_blocks / group_blocks2; + + int4 *sh_s_stage = sh_s + s_sh_stage * pipe; + + if constexpr (b_type_id != vllm::kFE2M1f.id()) { + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + } else { + reinterpret_cast(&frag_s[k % 2])[0] = reinterpret_cast( + sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; + } + } else if (group_blocks >= b_sh_wr_iters) { + if constexpr (b_type_id != vllm::kFE2M1f.id()) { + reinterpret_cast(&frag_s[1])[0] = reinterpret_cast(&frag_s[0])[0]; + } else { + reinterpret_cast(&frag_s[1])[0] = reinterpret_cast(&frag_s[0])[0]; + } + } + } + + return; + } + + // Act-order case + + // Determine K of the "current" thread-block + int cur_k = slice_k_start + tb_k * full_pipe; + if (cur_k >= prob_k || cur_k >= slice_k_finish) { + return; + } + + // Reset (to current thread-block) since we read g_idx portion from the + // shared memory + cur_k = 0; + + // Progress to current iteration + cur_k += k % b_sh_wr_iters; + + // Determine "position" inside the thread-block (based on warp and + // thread-id) + auto warp_id = threadIdx.x / 32; + int warp_row = warp_id / tb_n_warps; + int warp_col = warp_id % tb_n_warps; + + cur_k += warp_row * 16 * b_sh_wr_iters; + + auto th_id = threadIdx.x % 32; + cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix + + int s_col_shift = + /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + (th_id / 4) * 
act_s_col_stride; + + if (is_same_group[pipe]) { + if (k % 2 == 0) { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + s_col_shift]; + } else { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); + } + + for (int i = 1; i < 4; i++) { + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); + } + return; + } + + int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int *sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + constexpr int k_frag_offsets[4] = {0, 1, 8, + 9}; // Tensor core offsets per thread + +#pragma unroll + for (int i = 0; i < 4; i++) { + int actual_k = cur_k + k_frag_offsets[i]; + + int group_id = sh_g_idx_int_ptr[actual_k]; + int rel_group_id = group_id - sh_first_group_id; + + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = sh_s[rel_group_id * s_sh_stride + s_col_shift]; + } + }; + + auto fetch_zp_to_registers = [&](int k, int full_pipe) { + // This code does not handle group_blocks == 0, + // which signifies act_order. + // has_zp implies AWQ, which doesn't have act_order, + static_assert(!has_zp || group_blocks != 0); + + if constexpr (has_zp && !is_zp_float) { + int pipe = full_pipe % stages; + + if constexpr (group_blocks == -1) { + // load only when starting a new slice + if (k == 0 && full_pipe == 0 || is_a_8bit) { +#pragma unroll + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; + } + } + } else if constexpr (group_blocks >= thread_k_blocks) { + constexpr int g = group_blocks / thread_k_blocks; + if (pipe % g == 0 && k % b_sh_wr_iters == 0 || is_a_8bit) { + int4 *sh_zp_stage = sh_zp + zp_sh_stage * (g * (pipe / g)); +#pragma unroll + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } + } else { + auto warp_id = threadIdx.x / 32; + + int warp_row = warp_id / tb_n_warps; + + int k_blocks = b_sh_wr_iters * warp_row + k % b_sh_wr_iters; + int cur_group_id = k_blocks / div_ceil(group_blocks, is_a_8bit ? 
2 : 1); + + int4 *sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + sh_zp_stage += cur_group_id * zp_sh_stride; + +#pragma unroll + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } + } + + else if constexpr (has_zp && is_zp_float) { + int pipe = full_pipe % stages; + + if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= thread_k_blocks) { + constexpr int g = group_blocks / thread_k_blocks; + if (pipe % g == 0 && k % b_sh_wr_iters == 0) { + int4 *sh_zp_stage = sh_zp + zp_sh_stage * (g * (pipe / g)); + reinterpret_cast(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd]; + } + } else if (group_blocks < b_sh_wr_iters || k % b_sh_wr_iters == 0) { + auto warp_id = threadIdx.x / 32; + + int warp_row = warp_id / tb_n_warps; + int k_blocks = b_sh_wr_iters * warp_row + k % b_sh_wr_iters; + int cur_group_id = k_blocks / group_blocks; + + int4 *sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + reinterpret_cast(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd + cur_group_id * zp_sh_stride]; + } + } + } + }; + + auto dequant_data = [&](int q, scalar_32bit_t *frag_b_ptr, int zp = 0) { + if constexpr (a_type.size_bits() != b_type.size_bits()) { + if constexpr (is_a_8bit && has_zp) { + sub_zp_and_dequant( + q, frag_b_ptr, zp); + } else { + dequant(q, frag_b_ptr); + } + } + }; + + // Execute the actual tensor core matmul of a sub-tile. + bool is_first_matmul_in_slice = true; + auto matmul = [&](int k, int pipe) { + if (is_a_8bit) { + return; + } + int k2 = k % 2; + constexpr int g = group_blocks > 0 ? div_ceil(group_blocks, thread_k_blocks) : 1; + const bool is_new_zp = (group_blocks == 0) || ((group_blocks > 0) && (group_blocks < b_sh_wr_iters || k == 0)) && (pipe % g == 0) || (group_blocks == -1 && is_first_matmul_in_slice); + if constexpr (has_zp && !is_zp_float) { + if (is_new_zp) { + if constexpr (group_blocks == -1) { + is_first_matmul_in_slice = false; + } + int zp_quant_0, zp_quant_1; + + if constexpr (b_type.size_bits() == 4) { + zp_quant_0 = frag_qzp[k2][0]; + zp_quant_1 = zp_quant_0 >> 8; + } else { + static_assert(b_type.size_bits() == 8); + zp_quant_0 = frag_qzp[k2][0]; + zp_quant_1 = frag_qzp[k2][1]; + } + + dequant_data(zp_quant_0, reinterpret_cast(&frag_zp)); + dequant_data(zp_quant_1, + reinterpret_cast(&frag_zp) + 2); + } + } + if constexpr (!dequant_skip_flop && has_zp && is_zp_float) { + if (is_new_zp) { + reinterpret_cast(&frag_zp)[0] = reinterpret_cast(&frag_zpf[k2])[0]; + } + } + + if constexpr (b_type == vllm::kFE2M1f) { + int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; + int s_quant_1 = reinterpret_cast(frag_s[k2])[1]; + + dequant_fp8_scales( + s_quant_0, reinterpret_cast(&frag_s[k2])); + dequant_fp8_scales( + s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); + } + +// We have the m dimension as the inner loop in order to encourage overlapping +// dequantization and matmul operations. 
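+// Illustrative timing (assuming 4-bit weights): at iteration j the two
+// dequant_data calls expand one packed 32-bit word into eight 16-bit weights
+// (frag_b0/frag_b1) while the tensor cores may still be executing the mma
+// instructions issued at iteration j - 1, so most of the dequantization cost
+// is hidden behind the matmul itself.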
+#pragma unroll + for (int j = 0; j < 4; j++) { + FragB frag_b0; + FragB frag_b1; + int b_quant_0, b_quant_1; + + if constexpr (b_type_id == vllm::kFE2M1f.id()) { + b_quant_1 = frag_b_quant[k2][0][j]; + b_quant_0 = b_quant_1 << 8; + } else if constexpr (b_type.size_bits() == 4) { + b_quant_0 = frag_b_quant[k2][0][j]; + b_quant_1 = b_quant_0 >> 8; + } else { + static_assert(b_type.size_bits() == 8); + int *frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k2]); + b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; + b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; + } + + dequant_data(b_quant_0, reinterpret_cast(&frag_b0)); + dequant_data(b_quant_1, reinterpret_cast(&frag_b1)); + + if constexpr (dequant_skip_flop && has_zp && !is_zp_float && !is_a_8bit) { + sub_zp(frag_b0, frag_zp[j], 0); + sub_zp(frag_b1, frag_zp[j], 1); + } + + // Apply scale to frag_b0 + if constexpr (has_act_order && !is_a_8bit) { + static_assert(group_blocks != -1); + scale4(frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j], + act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0); + scale4(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j], + act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1); + } else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float && group_blocks == -1 && !is_a_8bit) { + int idx = (threadIdx.x / 4) % 2; + scalar_t2 s2 = Adtype::nums2num2( + reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 0])[idx], + reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 1])[idx]); + if (is_new_zp) + frag_zp[j] = __hmul2(frag_zp[j], s2); + scale_and_sub(frag_b0, s2.x, frag_zp[j].x); + scale_and_sub(frag_b1, s2.y, frag_zp[j].y); + } else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1 && !is_a_8bit) { + if (is_new_zp) + frag_zp[j] = __hmul2(frag_zp[j], + *reinterpret_cast(&frag_s[k2][j])); + scale_and_sub(frag_b0, frag_s[k2][j][0].x, frag_zp[j].x); + scale_and_sub(frag_b1, frag_s[k2][j][0].y, frag_zp[j].y); + } else if constexpr (group_blocks != -1 && !is_a_8bit) { + scale(frag_b0, frag_s[k2][j], 0); + scale(frag_b1, frag_s[k2][j], 1); + } + +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + if constexpr (m_block_size_8) { + mma_trans(frag_a[k2][i], frag_b0, frag_b1, + frag_c[i][j][0]); + } else { + mma(frag_a[k2][i], frag_b0, + frag_c[i][j][0]); + mma(frag_a[k2][i], frag_b1, + frag_c[i][j][1]); + } + } + } + }; + + auto matmul_a8 = [&](int k) { + int k2 = k % 2; +#pragma unroll + for (int j = 0; j < 2; j++) { + FragB frag_b[2]; + + if (is_a_8bit && b_type.size_bits() == 4 && !has_zp) { + dequant_data(frag_b_quant[k2][0][j * 2], + reinterpret_cast(&frag_b)); + dequant_data(frag_b_quant[k2][0][j * 2 + 1], + reinterpret_cast(&frag_b) + 2); + } else if (is_a_8bit && b_type.size_bits() == 4 && has_zp) { + int off = (threadIdx.x / 32) % 2 * 2 + j; + int zp = (frag_qzp[k2][0] >> (off * 8)) & 0xF; + dequant_data(frag_b_quant[k2][0][j * 2], + reinterpret_cast(&frag_b), zp); + zp = (frag_qzp[k2][0] >> (off * 8 + 4)) & 0xF; + dequant_data(frag_b_quant[k2][0][j * 2 + 1], + reinterpret_cast(&frag_b) + 2, zp); + } else { + reinterpret_cast(&frag_b)[0] = reinterpret_cast(&frag_b_quant[k2][j])[0]; + reinterpret_cast(&frag_b)[1] = reinterpret_cast(&frag_b_quant[k2][j])[1]; + } + +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + mma( + frag_a[k2][i], frag_b[0], + (group_blocks == -1 ? frag_c : frag_c_tmp)[i][j][0]); + mma( + frag_a[k2][i], frag_b[1], + (group_blocks == -1 ? 
frag_c : frag_c_tmp)[i][j][1]); + } + + if constexpr (group_blocks != -1) { + if (group_blocks == 2 || k == 1) { + if constexpr (a_type == vllm::kS8) { + int2 s_vals[2]; + s_vals[0] = { + (int)reinterpret_cast(&frag_s[k2][j * 2][0])[0], + (int)reinterpret_cast(&frag_s[k2][j * 2][0])[1]}; + s_vals[1] = { + (int)reinterpret_cast(&frag_s[k2][j * 2 + 1][0])[0], + (int)reinterpret_cast(&frag_s[k2][j * 2 + 1][0])[1]}; + +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { +#pragma unroll + for (int g = 0; g < 4; g++) { + int scale = reinterpret_cast(&s_vals[0])[g % 2]; + *reinterpret_cast(&frag_c[i][j][0][g]) += *reinterpret_cast(&frag_c_tmp[i][j][0][g]) * scale; + frag_c_tmp[i][j][0][g] = 0.0f; + } + +#pragma unroll + for (int g = 0; g < 4; g++) { + int scale = reinterpret_cast(&s_vals[1])[g % 2]; + *reinterpret_cast(&frag_c[i][j][1][g]) += *reinterpret_cast(&frag_c_tmp[i][j][1][g]) * scale; + frag_c_tmp[i][j][1][g] = 0.0f; + } + } + } else { + float2 s_vals[2]; + if constexpr (s_type_id != vllm::kFE8M0fnu.id()) { + static_assert(a_type.size_bits() == 16 || s_type.size_bits() == 16); + s_vals[0] = Cdtype::num22float2(frag_s[k2][j * 2][0]); + s_vals[1] = Cdtype::num22float2(frag_s[k2][j * 2 + 1][0]); + } else { + int32_t *s_vals_int = reinterpret_cast(&s_vals[0]); + int32_t s_vals_e8m0 = *reinterpret_cast(&frag_s[k2][j][0]); + + s_vals_int[0] = (s_vals_e8m0 & 0xFF) << 23; + s_vals_int[1] = (s_vals_e8m0 & 0xFF00) << 15; + s_vals_int[2] = (s_vals_e8m0 & 0xFF0000) << 7; + s_vals_int[3] = (s_vals_e8m0 & 0xFF000000) >> 1; + } + +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { +#pragma unroll + for (int g = 0; g < 4; g++) { + float scale = reinterpret_cast(&s_vals[0])[g % 2]; + frag_c[i][j][0][g] += frag_c_tmp[i][j][0][g] * scale; + frag_c_tmp[i][j][0][g] = 0.0f; + } + +#pragma unroll + for (int g = 0; g < 4; g++) { + float scale = reinterpret_cast(&s_vals[1])[g % 2]; + frag_c[i][j][1][g] += frag_c_tmp[i][j][1][g] * scale; + frag_c_tmp[i][j][1][g] = 0.0f; + } + } + } + } + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the + // number of warps while keeping the n dimension of a tile reasonable, we have + // multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&]() { + constexpr int red_off = threads / b_sh_stride_threads / 2; + if (red_off >= 1) { + auto red_idx = threadIdx.x / b_sh_stride_threads; + constexpr int red_sh_stride = b_sh_stride_threads * (is_a_8bit ? 2 : 4) * 2; + constexpr int red_sh_delta = b_sh_stride_threads; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + (threadIdx.x % b_sh_stride_threads); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. + +#pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { +#pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { +#pragma unroll + for (int j = 0; j < (is_a_8bit ? 2 : 4) * 2; + j += (m_block_size_8 ? 2 : 1)) { + int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float *c_rd = reinterpret_cast( + &sh_red[red_sh_delta * j + red_sh_rd]); + float *c_wr = reinterpret_cast(&sh_red[red_sh_wr]); +#pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast( + frag_c)[(is_a_8bit ? 
2 : 4) * 2 * m_block + j][k] + += c_rd[k] + c_wr[k]; + } + sh_red[red_sh_wr] = reinterpret_cast( + &frag_c)[(is_a_8bit ? 2 : 4) * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { +#pragma unroll + for (int i = 0; i < (is_a_8bit ? 2 : 4) * 2; + i += (m_block_size_8 ? 2 : 1)) { + float *c_rd = reinterpret_cast(&sh_red[red_sh_delta * i + red_sh_rd]); +#pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast( + frag_c)[(is_a_8bit ? 2 : 4) * 2 * m_block + i][j] + += c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. + auto global_reduce_fp16 = [&](bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. To do this, we write out + // results in FP16 (but still reduce with FP32 compute). + constexpr int active_threads = 32 * tb_n_warps; + if (threadIdx.x < active_threads) { + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 8 * c_gl_stride * (is_a_8bit ? 2 : 1); + int c_gl_wr_delta_i = 4 * (active_threads / 32); + int c_gl_wr; + if constexpr (m_block_size_8) { + c_gl_wr = c_gl_stride * ((threadIdx.x % 4) * 2) + 4 * (threadIdx.x / 32) + (threadIdx.x % 32) / 8; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + } else { + c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) * (is_a_8bit ? 2 : 1) + 4 * (threadIdx.x / 32) + threadIdx.x % 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col * (is_a_8bit ? 2 : 1); + } + constexpr int c_sh_wr_delta = active_threads; + auto c_sh_wr = threadIdx.x; + + int row = (threadIdx.x % 32) / 4; + + if (!first) { +// Interestingly, doing direct global accesses here really seems to mess up +// the compiler and lead to slowdowns, hence we also use async-copies even +// though these fetches are not actually asynchronous. +#pragma unroll + for (int i = 0; i < (m_block_size_8 ? 2 : thread_m_blocks * 4); i++) { + if constexpr (m_block_size_8) { + cp_async4_pred(&sh_red[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + i * c_gl_stride + (threadIdx.x % 8) / 4 * c_gl_wr_delta_i], + (threadIdx.x % 4) * 2 + i < prob_m); + } else if constexpr (is_a_8bit) { + int2 *sh_red_int2 = reinterpret_cast(sh_red); + int2 *c_int2 = reinterpret_cast(C); + cp_async2_ca_pred( + &sh_red_int2[c_sh_wr + c_sh_wr_delta * i], + &c_int2[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); + } else { + cp_async4_pred( + &sh_red[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); + } + } + cp_async_fence(); + cp_async_wait<0>(); + } + +#pragma unroll + for (int i = 0; i < (m_block_size_8 ? 
2 : thread_m_blocks * 4); i++) { + bool mask = (!m_block_size_8) && (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) || (m_block_size_8) && ((threadIdx.x % 4) * 2 + i < prob_m); + if (mask) { + if (!first) { + c_scalar_t *c_red_f16; + if constexpr (is_a_8bit) { + int2 tmp = reinterpret_cast(sh_red)[c_sh_wr + i * c_sh_wr_delta]; + c_red_f16 = reinterpret_cast(&tmp); + } else { + int4 tmp = sh_red[c_sh_wr + i * c_sh_wr_delta]; + c_red_f16 = reinterpret_cast(&tmp); + } +#pragma unroll + for (int j = 0; j < 2 * (is_a_8bit ? 2 : 4); j++) { + int delta = 0; + if constexpr (m_block_size_8) { + delta = j % 2 == 1 ? -2 : 0; + } + reinterpret_cast( + &frag_c)[(is_a_8bit ? 2 : 4) * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta] + += Cdtype::num2float(c_red_f16[j]); + } + } + if (!last) { + c_scalar_t c_f16[is_a_8bit ? 4 : 8]; +#pragma unroll + for (int j = 0; j < 2 * (is_a_8bit ? 2 : 4); j++) { + int delta = 0; + if constexpr (m_block_size_8) { + delta = j % 2 == 1 ? -2 : 0; + } + c_f16[j] = Cdtype::float2num(reinterpret_cast( + &frag_c)[(is_a_8bit ? 2 : 4) * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta]); + } + if constexpr (m_block_size_8) { + C[c_gl_wr + i * c_gl_stride + (threadIdx.x % 8) / 4 * c_gl_wr_delta_i] = *reinterpret_cast(c_f16); + } else if constexpr (is_a_8bit) { + int2 *c_int2 = reinterpret_cast(C); + c_int2[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = *reinterpret_cast(c_f16); + } else { + C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = *reinterpret_cast(c_f16); + } + } + } + } + } + }; + + // Globally reduce over threadblocks that compute the same column block. + // We use a tmp C buffer to reduce in full fp32 precision. + auto global_reduce_fp32 = [&](bool first = false, bool last = false) { + constexpr int tb_m = thread_m_blocks * 16; + constexpr int tb_n = thread_n_blocks * 16; + + constexpr int c_size = tb_m * tb_n * sizeof(float) / 16; + + constexpr int active_threads = 32 * tb_n_warps; + bool is_th_active = threadIdx.x < active_threads; + + constexpr int num_floats = thread_m_blocks * (is_a_8bit ? 2 : 4) * 2 * 4; + constexpr int th_size = num_floats * sizeof(float) / 16; + + int c_cur_offset = locks_off * c_size; + + if (!is_th_active) { + return; + } + + if (!first) { + float *frag_c_ptr = reinterpret_cast(&frag_c); +#pragma unroll + for (int k = 0; k < th_size; k += (m_block_size_8 ? 2 : 1)) { + sh_red[threadIdx.x] = C_tmp[c_cur_offset + active_threads * k + threadIdx.x]; + + float *sh_c_ptr = reinterpret_cast(&sh_red[threadIdx.x]); +#pragma unroll + for (int f = 0; f < 4; f++) { + frag_c_ptr[k * 4 + f] += sh_c_ptr[f]; + } + } + } + + if (!last) { + int4 *frag_c_ptr = reinterpret_cast(&frag_c); +#pragma unroll + for (int k = 0; k < th_size; k += (m_block_size_8 ? 2 : 1)) { + C_tmp[c_cur_offset + active_threads * k + threadIdx.x] = frag_c_ptr[k]; + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually + // reshuffle matrix fragments in this step, the reduction above is performed + // in fragment layout. 
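+  // Note the padded stride used below: c_sh_stride = 2 * thread_n_blocks + 1
+  // int4 elements per row. The extra "+ 1" staggers the start of successive
+  // rows across the 32 shared-memory banks (a standard padding trick), so the
+  // strided fragment writes and the linear reads that follow avoid mapping to
+  // the same bank; the exact conflict pattern depends on thread_n_blocks.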
+ auto write_result = [&](bool last) { + int c_gl_stride = prob_n / 8; + constexpr int c_sh_stride = 2 * thread_n_blocks + 1; + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + constexpr int c_sh_rd_delta = c_sh_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + int c_sh_wr; + if constexpr (m_block_size_8) { + c_sh_wr = (8 * c_sh_stride) * ((threadIdx.x % 32) % 4 * 2) + (threadIdx.x % 32) / 4; + c_sh_wr += 64 * (threadIdx.x / 32); + } else { + c_sh_wr = (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + c_sh_wr += (is_a_8bit ? 16 : 32) * (threadIdx.x / 32); + } + + int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); + + int c_gl_wr_end = c_gl_stride * prob_m; + // We first reorder in shared memory to guarantee the most efficient final + // global write patterns + auto write = [&](int idx, float c0, float c1, FragS &s, FragS &b_bias) { + if constexpr (b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { + c0 *= global_scale_f32; + c1 *= global_scale_f32; + } + c_scalar_t2 res = Cdtype::nums2num2(Cdtype::float2num(c0), Cdtype::float2num(c1)); + + // For per-column quantization we finally apply the scale here (only for + // 4-bit) + if constexpr (!has_act_order && group_blocks == -1 && !is_a_8bit && b_type.size_bits() == 4 && (has_zp && dequant_skip_flop || !has_zp)) { + c_scalar_t2 tmp_scale = s[0]; + if constexpr (m_block_size_8) { + tmp_scale = Cdtype::num2num2( + reinterpret_cast(&s[0])[(threadIdx.x % 8) / 4]); + } + res = __hmul2(res, tmp_scale); + } + if (has_bias && last) { + c_scalar_t2 tmp_bias = b_bias[0]; + if constexpr (m_block_size_8) { + tmp_bias = Cdtype::num2num2( + reinterpret_cast(&b_bias[0])[(threadIdx.x % 8) / 4]); + } + res = __hadd2(res, tmp_bias); + } + + if constexpr (m_block_size_8) { + ((c_scalar_t *)sh_red)[idx] = res.x; + ((c_scalar_t *)sh_red)[idx + 8 * c_sh_stride] = res.y; + } else { + ((c_scalar_t2 *)sh_red)[idx] = res; + } + }; + + if (threadIdx.x / 32 < tb_n_warps) { +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { +#pragma unroll + for (int j = 0; j < (is_a_8bit ? 
2 : 4); j++) { + if constexpr (m_block_size_8) { + int wr = c_sh_wr + 16 * j; + write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1], + frag_s[j / 2][2 * (j % 2) + 0], + frag_bias[j / 2][2 * (j % 2) + 0]); + write(wr + 8, frag_c[i][j][0][2], frag_c[i][j][0][3], + frag_s[j / 2][2 * (j % 2) + 1], + frag_bias[j / 2][2 * (j % 2) + 1]); + } else { + int wr = c_sh_wr + 8 * j; + write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], + frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0], + frag_bias[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], + frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0], + frag_bias[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], + frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1], + frag_bias[j / 2][2 * (j % 2) + 1]); + write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], + frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1], + frag_bias[j / 2][2 * (j % 2) + 1]); + } + } + c_sh_wr += 16 * (4 * c_sh_stride); + } + } + __syncthreads(); + +#pragma unroll + for (int i = 0; + i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); + i++) { + if (c_gl_wr < c_gl_wr_end) { + if (use_atomic_add && slice_count > 1) { + c_scalar_t2 *C_half2 = reinterpret_cast(&C[c_gl_wr]); + c_scalar_t2 *sh_red_half2 = reinterpret_cast(&sh_red[c_sh_rd]); +#pragma unroll + for (int a = 0; a < 4; a++) { + atomicAdd(&C_half2[a], sh_red_half2[a]); + } + } else { + C[c_gl_wr] = sh_red[c_sh_rd]; + } + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; + } + } + __syncthreads(); + }; + + // Start global fetch and register load pipelines. + auto start_pipes = [&]() { + +#pragma unroll + for (int i = 0; i < stages - 1; i++) { + if (has_act_order && i == 0) { + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + fetch_act_order_scales_to_shared(true, g_idx[slice_k_start], + g_idx[last_g_idx]); + } + + if constexpr (has_zp && !is_zp_float && group_blocks == -1) { + if (i == 0) { + fetch_col_zp_to_shared(); + if constexpr (!dequant_skip_flop) { + fetch_col_scale_to_shared(); + } + } + } + fetch_to_shared(i, i, i < slice_iters); + } + + zero_accums(); + wait_for_stage(); + init_same_group(0); + fetch_to_registers(0, 0); + fetch_scales_to_registers(0, 0); + fetch_zp_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + if constexpr (has_act_order) { + slice_k_start_shared_fetch += tb_k * (stages - 1); + } + }; + if (slice_iters) { + start_pipes(); + } + + // Main loop. + while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines + // have even length meaning that the next iteration will always start at + // index 0. + +#pragma unroll + for (int pipe = 0; pipe < stages;) { +#pragma unroll + for (int k = 0; k < b_sh_wr_iters; k++) { + fetch_to_registers(k + 1, pipe % stages); + fetch_scales_to_registers(k + 1, pipe); + fetch_zp_to_registers(k + 1, pipe); + if (k == b_sh_wr_iters - 2) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, + slice_iters >= stages); + pipe++; + wait_for_stage(); + init_same_group(pipe % stages); + } + + if constexpr (!is_a_8bit) { + matmul(k, pipe - (k >= b_sh_wr_iters - 2 ? 
1 : 0));
+        } else {
+          static_assert(group_blocks != 0 && group_blocks != 1);
+          matmul_a8(k);
+        }
+      }
+      slice_iters--;
+      if (slice_iters == 0) {
+        break;
+      }
+    }
+
+    a_gl_rd += a_gl_rd_delta_o * stages;
+
+    if constexpr (has_act_order) {
+      slice_k_start += tb_k * stages;
+
+      if (slice_k_start < prob_k) {
+        slice_k_start_shared_fetch += tb_k * stages;
+        int first_group_id = g_idx[slice_k_start];
+        int last_g_idx = slice_k_start + stages * tb_k * 2;
+        if (last_g_idx >= prob_k) {
+          last_g_idx = prob_k - 1;
+        }
+        int last_group_id = g_idx[last_g_idx];
+        if (last_group_id >= sh_first_group_id + sh_num_groups) {
+          fetch_act_order_scales_to_shared(false, first_group_id,
+                                           last_group_id);
+          __syncthreads();
+        }
+      }
+    }
+
+    // Process results and, if necessary, proceed to the next column slice.
+    // While this pattern may not be the most readable, other ways of writing
+    // the loop seemed to result in noticeably worse performance after
+    // compilation.
+    if (slice_iters == 0) {
+      // convert fp16 accum to fp32 for reduction
+      if constexpr (use_fp16_accum) {
+#pragma unroll
+        for (int i = 0; i < (thread_m_blocks * (is_a_8bit ? 2 : 4) * 2); i++) {
+          float *frag_c_part_float = reinterpret_cast<float *>(frag_c) + i * 4;
+          scalar_t *frag_c_part_half = reinterpret_cast<scalar_t *>(frag_c_part_float);
+
+          // walk backwards so the in-place half -> float widening never
+          // overwrites a value that has not been converted yet
+#pragma unroll
+          for (int j = 3; j >= 0; j--) {
+            frag_c_part_float[j] = Cdtype::num2float(frag_c_part_half[j]);
+          }
+        }
+      }
+
+      if constexpr (is_a_8bit) {
+        float frag_a_s[2 * thread_m_blocks];
+
+        for (int i = 0; i < 2 * thread_m_blocks; i++) {
+          frag_a_s[i] = sh_a_s[i * 8 + (threadIdx.x % 32) / 4];
+        }
+
+#pragma unroll
+        for (int j = 0; j < 2; j++) {
+#pragma unroll
+          for (int i = 0; i < thread_m_blocks; i++) {
+#pragma unroll
+            for (int g = 0; g < 4; g++) {
+              float c_val = frag_c[i][j][0][g];
+
+              if constexpr (a_type == vllm::kS8) {
+                c_val = __int2float_rn(*reinterpret_cast<int *>(&c_val));
+              }
+              float s_val = frag_a_s[i * 2 + g / 2];
+              frag_c[i][j][0][g] = c_val * s_val;
+            }
+#pragma unroll
+            for (int g = 0; g < 4; g++) {
+              float c_val = frag_c[i][j][1][g];
+
+              if constexpr (a_type == vllm::kS8) {
+                c_val = __int2float_rn(*reinterpret_cast<int *>(&c_val));
+              }
+              float s_val = frag_a_s[i * 2 + g / 2];
+              frag_c[i][j][1][g] = c_val * s_val;
+            }
+          }
+        }
+      }
+
+      cp_async_wait<0>();
+      bool last = slice_idx == slice_count - 1;
+      // For per-column scales, we only fetch them here in the final step before
+      // write-out
+      if constexpr (!has_act_order && group_blocks == -1 && (has_zp && dequant_skip_flop || !has_zp)) {
+        if (b_type.size_bits() == 8 || (last || use_atomic_add) || is_a_8bit) {
+          if (s_sh_wr_pred) {
+            cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
+          }
+          cp_async_fence();
+        }
+      }
+
+      thread_block_reduce();
+
+      if (has_bias && last) {
+        __syncthreads();
+        cp_async4_pred(&sh_bias[bias_sh_wr], &b_bias_ptr[bias_gl_rd],
+                       threadIdx.x < 16 * thread_n_blocks / 8);
+        cp_async_fence();
+      }
+
+      if constexpr (!has_act_order && group_blocks == -1 && (has_zp && dequant_skip_flop || !has_zp || is_a_8bit)) {
+        if constexpr (is_a_8bit) {
+          cp_async_wait<0>();
+          __syncthreads();
+          if (threadIdx.x / 32 < tb_n_warps) {
+            reinterpret_cast<int4 *>(&frag_s)[0] = sh_s[s_sh_rd + 0];
+          }
+        } else if (b_type.size_bits() == 8 || (last || use_atomic_add)) {
+          cp_async_wait<0>();
+          __syncthreads();
+          if (threadIdx.x / 32 < tb_n_warps) {
+            reinterpret_cast<int4 *>(&frag_s)[0] = sh_s[s_sh_rd + 0];
+            reinterpret_cast<int4 *>(&frag_s)[1] = sh_s[s_sh_rd + 4];
+            if constexpr (m_block_size_8) {
+              int idx = (threadIdx.x / 4) % 2;
+              c_scalar_t2 *frag_s_half2 = reinterpret_cast<c_scalar_t2 *>(frag_s);
+#pragma unroll
for (int i = 0; i < 8; i++) { + frag_s_half2[i] = Cdtype::num2num2( + reinterpret_cast(&frag_s_half2[i])[idx]); + } + } + } + } + } + + // For 8-bit channelwise, we apply the scale before the global reduction + // that converts the fp32 results to fp16 (so that we avoid possible + // overflow in fp16) + if constexpr (!has_act_order && group_blocks == -1 && is_a_8bit) { +#pragma unroll + for (int j = 0; j < 2; j++) { + float2 aa[2]; + aa[0] = Cdtype::num22float2(frag_s[0][j * 2][0]); + aa[1] = Cdtype::num22float2(frag_s[0][j * 2 + 1][0]); + +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { +#pragma unroll + for (int g = 0; g < 4; g++) { + float scale = reinterpret_cast(&aa[0])[g % 2]; + frag_c[i][j][0][g] *= scale; + } + +#pragma unroll + for (int g = 0; g < 4; g++) { + float scale = reinterpret_cast(&aa[1])[g % 2]; + frag_c[i][j][1][g] *= scale; + } + } + } + } else if (!has_act_order && group_blocks == -1 && b_type.size_bits() == 8 && (has_zp && dequant_skip_flop || !has_zp)) { + if (threadIdx.x / 32 < tb_n_warps) { +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { +#pragma unroll + for (int j = 0; j < 4; j++) { + scale_float( + reinterpret_cast(&frag_c[i][j][0][0]), + frag_s[j / 2][2 * (j % 2) + 0]); + scale_float( + reinterpret_cast(&frag_c[i][j][0][2]), + frag_s[j / 2][2 * (j % 2) + (m_block_size_8 ? 1 : 0)]); + + if constexpr (!m_block_size_8) { + scale_float( + reinterpret_cast(&frag_c[i][j][1][0]), + frag_s[j / 2][2 * (j % 2) + 1]); + scale_float( + reinterpret_cast(&frag_c[i][j][1][2]), + frag_s[j / 2][2 * (j % 2) + 1]); + } + } + } + } + } + + if (slice_count > 1 && !use_atomic_add) { + // only globally reduce if there is more than one block in a slice + barrier_acquire(&locks[locks_off], slice_idx); + if (use_fp32_reduce) { + global_reduce_fp32(slice_idx == 0, last); + } else { + global_reduce_fp16(slice_idx == 0, last); + } + barrier_release(&locks[locks_off], last); + } + + if (has_bias && last) { + cp_async_wait<0>(); + __syncthreads(); + reinterpret_cast(&frag_bias)[0] = sh_bias[bias_sh_rd]; + if constexpr (!is_a_8bit) { + reinterpret_cast(&frag_bias)[1] = sh_bias[bias_sh_rd + 4]; + } + __syncthreads(); + } + + if (use_atomic_add && slice_count > 1 && slice_idx != 0) { + wait_negative_and_add(&locks[locks_off]); + } + if (last || use_atomic_add) { + // only the last block in a slice actually writes the result + write_result(last); + } + slice_row = 0; + if (!in_part2) { + slice_col_par += gridDim.x; + } else { + slice_col_par++; + slice_col++; + } + is_first_matmul_in_slice = true; + init_slice(); + + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); + b_gl_rd += b_sh_stride * slice_col + b_gl_rd_delta_o * slice_row; + + bias_gl_rd = (thread_n_blocks * 16 / 8) * slice_col + threadIdx.x; + // Update slice k/n for scales loading + if constexpr (has_act_order) { + slice_k_start = tb_k * slice_row; + slice_k_finish = slice_k_start + tb_k * slice_iters; + slice_k_start_shared_fetch = slice_k_start; + slice_n_offset = act_s_col_tb_stride * slice_col; + } else { + if constexpr (group_blocks == -1) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } else if constexpr (group_blocks >= thread_k_blocks) { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + s_sh_stride * slice_col + threadIdx.x; + 
zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + zp_sh_stride * slice_col + threadIdx.x;
+        } else {
+          s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks + threadIdx.x / s_sh_stride) + s_sh_stride * slice_col + threadIdx.x % s_sh_stride;
+          zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks + threadIdx.x / zp_sh_stride) + zp_sh_stride * slice_col + threadIdx.x % zp_sh_stride;
+        }
+      }
+      start_pipes();
+    }
+  }
+  }
+}
+
+} // namespace MARLIN_NAMESPACE_NAME
+
+#endif
diff --git a/src/infiniop/ops/awq_marlin_gemm/nvidia/.gitignore b/src/infiniop/ops/awq_marlin_gemm/nvidia/.gitignore
new file mode 100644
index 000000000..c52210efc
--- /dev/null
+++ b/src/infiniop/ops/awq_marlin_gemm/nvidia/.gitignore
@@ -0,0 +1,4 @@
+sm*_kernel_*.cu
+kernel_selector.h
+kernel_*.cu
+
diff --git a/src/infiniop/ops/awq_marlin_gemm/nvidia/awq_marlin_gemm_nvidia.cu b/src/infiniop/ops/awq_marlin_gemm/nvidia/awq_marlin_gemm_nvidia.cu
new file mode 100644
index 000000000..f682772ff
--- /dev/null
+++ b/src/infiniop/ops/awq_marlin_gemm/nvidia/awq_marlin_gemm_nvidia.cu
@@ -0,0 +1,367 @@
+#if defined(ENABLE_NVIDIA_API)
+#include "../../../devices/nvidia/nvidia_handle.cuh"
+#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
+#include "../core/utils.h"
+#include "awq_marlin_gemm_nvidia.cuh"
+#include "kernel.cuh"
+
+template <typename scalar_t, typename Tdata>
+infiniStatus_t awq_marlin_gemm_kernel(
+    const void *a,
+    void *c,
+    const void *b_q_weight,
+    void *b_bias,
+    void *b_scales,
+    void *a_scales,
+    void *global_scale,
+    void *b_zeros,
+    void *g_idx,
+    void *perm,
+    int64_t b_type_id,
+    bool is_k_full,
+    bool use_atomic_add,
+    bool use_fp32_reduce,
+    bool is_zp_float,
+    int size_m,
+    int size_k,
+    int size_n,
+    int b_q_size_0,
+    int b_q_size_1,
+    int a_stride_0,
+    int b_zeros_size_1,
+    int num_groups,
+    void *total_buffer,
+    cudaStream_t stream) {
+    // scalar_t *a, Tdata *b_scales
+    vllm::ScalarTypeId a_type_id, c_type_id, s_type_id;
+
+    if constexpr (std::is_same<scalar_t, half>::value) {
+        a_type_id = vllm::kFloat16.id();
+        c_type_id = vllm::kFloat16.id();
+    } else if constexpr (std::is_same<scalar_t, __nv_bfloat16>::value) {
+        a_type_id = vllm::kBFloat16.id();
+        c_type_id = vllm::kBFloat16.id();
+    } else {
+        // Here c and b_scales share the same element type.
+        if constexpr (std::is_same<Tdata, half>::value) {
+            c_type_id = vllm::kFloat16.id();
+        } else if constexpr (std::is_same<Tdata, __nv_bfloat16>::value) {
+            c_type_id = vllm::kBFloat16.id();
+        } else {
+            c_type_id = vllm::kBFloat16.id();
+            host::RuntimeCheck(c != nullptr, "c must be passed for W4A8-FP4\n");
+        }
+        if constexpr (std::is_same<scalar_t, __nv_fp8_e4m3>::value) {
+            a_type_id = vllm::kFE4M3fn.id();
+        } else if constexpr (std::is_same<scalar_t, int8_t>::value) {
+            a_type_id = vllm::kS8.id();
+        } else {
+            host::RuntimeCheck(false, "unsupported `a` scalar_type\n");
+            return INFINI_STATUS_BAD_TENSOR_DTYPE;
+        }
+    }
+
+    s_type_id = c_type_id;
+    if (b_type_id == vllm::kFE2M1f.id()) {
+        if constexpr (std::is_same<Tdata, __nv_fp8_e4m3>::value) {
+            s_type_id = vllm::kFE4M3fn.id();
+        } else if constexpr (std::is_same<Tdata, __nv_fp8_e8m0>::value) {
+            s_type_id = vllm::kFE8M0fnu.id();
+        } else {
+            host::RuntimeCheck(false,
+                               "When b_type = float4_e2m1f, b_scale scalar type must be",
+                               "float8_e4m3fn (for NVFP4) or float8_e8m0fnu (for MXFP4).");
+            return INFINI_STATUS_BAD_TENSOR_DTYPE;
+        }
+    }
+
+    vllm::ScalarType a_type = vllm::ScalarType::from_id(a_type_id);
+    vllm::ScalarType b_type = vllm::ScalarType::from_id(b_type_id);
+    vllm::ScalarType c_type = vllm::ScalarType::from_id(c_type_id);
+    vllm::ScalarType s_type = vllm::ScalarType::from_id(s_type_id);
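+    // Dispatch summary (sketch; only the first two <scalar_t, Tdata> pairs are
+    // instantiated by Descriptor::calculate, the 8-bit rows are assumptions
+    // about future activations):
+    //   <half, half>                   -> a = fp16, c = s = fp16
+    //   <__nv_bfloat16, __nv_bfloat16> -> a = bf16, c = s = bf16
+    //   <8-bit scalar_t, Tdata>        -> a = fp8/int8, c and s follow Tdata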
+    int pack_factor = 32 / b_type.size_bits();
+
+    // Verify a = [size_m, size_k]
+
+    // Verify b
+    host::RuntimeCheck(
+        size_k % MARLIN_NAMESPACE_NAME::tile_size == 0, "size_k = ", size_k,
+        " is not divisible by tile_size = ", MARLIN_NAMESPACE_NAME::tile_size);
+    host::RuntimeCheck((size_k / MARLIN_NAMESPACE_NAME::tile_size) == b_q_size_0,
+                       "Shape mismatch: b_q_weight.size(0) = ", b_q_size_0,
+                       ", size_k = ", size_k,
+                       ", tile_size = ", MARLIN_NAMESPACE_NAME::tile_size);
+    host::RuntimeCheck(
+        b_q_size_1 % MARLIN_NAMESPACE_NAME::tile_size == 0,
+        "b_q_weight.size(1) = ", b_q_size_1,
+        " is not divisible by tile_size = ", MARLIN_NAMESPACE_NAME::tile_size);
+    int actual_size_n = (b_q_size_1 / MARLIN_NAMESPACE_NAME::tile_size) * pack_factor;
+    host::RuntimeCheck(size_n == actual_size_n, "size_n = ", size_n,
+                       ", actual_size_n = ", actual_size_n);
+
+    // Verify device and strides
+
+    // We use int4 (16 bytes) to load A, so A must be aligned to 16 bytes
+    host::RuntimeCheck(a_stride_0 % 8 == 0, "A.stride(0) must be divisible by 8");
+    host::RuntimeCheck(reinterpret_cast<uintptr_t>(a) % 16 == 0, "A must be aligned to 16 bytes");
+
+    if (a_scales != nullptr) {
+        host::RuntimeCheck(a_type.size_bits() == 8,
+                           "a_scales can only be used for 8-bit activation.");
+    } else {
+        host::RuntimeCheck(a_type.size_bits() != 8,
+                           "the a_scales parameter must be passed for 8-bit activation.");
+    }
+
+    int device_id = 0;
+    // thread_k: `k` size of a thread_tile in `weights` (can usually be left as
+    // auto -1)
+    int thread_k = -1;
+    // thread_n: `n` size of a thread_tile in `weights` (can usually be left as
+    // auto -1)
+    int thread_n = -1;
+    // sms: number of SMs to use for the kernel
+    int sms = -1;
+    cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, device_id);
+
+    // Alloc buffers
+    float *c_tmp = nullptr;
+    void *a_tmp = nullptr;
+    void *workspace = nullptr;
+
+    int c_tmp_bytes = 0;
+    // Alloc C tmp buffer that is going to be used for the global reduce
+    if (use_fp32_reduce) {
+        int max_m_block_size = (size_m + 16 - 1) / 16 * 16;
+        max_m_block_size = min(max_m_block_size, 64);
+        int max_c_tmp_size = sms * max_m_block_size * MARLIN_NAMESPACE_NAME::max_thread_n;
+        c_tmp_bytes = max_c_tmp_size * sizeof(float);
+    }
+
+    // Detect groupsize and act_order
+
+    // b_scales = [num_groups, size_n]
+    // g_idx.size(-1) == size_k && perm.size(-1) == size_k
+    int a_tmp_bytes = 0;
+    bool has_act_order = false;
+
+    if (g_idx != nullptr && perm != nullptr) {
+        has_act_order = true;
+    }
+    int group_size = -1;
+    if (has_act_order) {
+        a_tmp_bytes = size_m * size_k * sizeof(scalar_t);
+        if (is_k_full) {
+            host::RuntimeCheck(num_groups > 1, "For act_order, num_groups must be > 1");
+            host::RuntimeCheck(size_k % num_groups == 0, "size_k = ", size_k,
+                               ", is not divisible by num_groups = ", num_groups);
+            group_size = size_k / num_groups;
+        } else {
+            group_size = 0;
+        }
+    } else {
+        if (num_groups > 1) {
+            host::RuntimeCheck(
+                size_k % num_groups == 0, "size_k = ", size_k,
+                ", is not divisible by b_scales.size(0) = ", num_groups);
+            group_size = size_k / num_groups;
+        } else {
+            group_size = -1;
+        }
+    }
+
+    int workspace_bytes = sms * sizeof(int64_t);
+    const int total_bytes = c_tmp_bytes + a_tmp_bytes + workspace_bytes;
+    // ===================== 3. Single workspace allocation, zeroed in one shot =====================
+    if (total_bytes > 0) {
+        cudaMemset(total_buffer, 0, total_bytes);
+    }
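+    // Layout of the caller-provided buffer (one contiguous allocation):
+    //   [ c_tmp : c_tmp_bytes ][ a_tmp : a_tmp_bytes ][ locks : workspace_bytes ]
+    // These sizes must stay in sync with Descriptor::create, which reserves the
+    // total workspace with the same formulas.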
+    // ===================== 4. Manually carve the buffer into sub-pointers (the key step) =====================
+    uint8_t *ptr = reinterpret_cast<uint8_t *>(total_buffer);
+    // carve out c_tmp
+    if (use_fp32_reduce && c_tmp_bytes > 0) {
+        c_tmp = reinterpret_cast<float *>(ptr);
+        ptr += c_tmp_bytes;
+    }
+    // carve out a_tmp
+    if (has_act_order && a_tmp_bytes > 0) {
+        a_tmp = ptr;
+        ptr += a_tmp_bytes;
+    }
+
+    // carve out the workspace (lock array)
+    if (workspace_bytes > 0) {
+        workspace = ptr;
+        ptr += workspace_bytes;
+    }
+
+    if (global_scale != nullptr) {
+        host::RuntimeCheck(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn,
+                           "global_scale can only be used for nvfp4 format.");
+    } else {
+        host::RuntimeCheck(!(b_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn),
+                           "the global_scale parameter must be passed for nvfp4 format.");
+    }
+    // b_bias = [size_n, 1]
+    bool has_bias = (b_bias != nullptr);
+
+    bool has_zp = (b_zeros != nullptr);
+    if (has_zp) {
+        host::RuntimeCheck(
+            b_type == vllm::kU4 || b_type == vllm::kU8,
+            "b_type must be u4 or u8 when has_zp = True. Got = ", b_type.str());
+    } else {
+        host::RuntimeCheck(b_type == vllm::kU4B8 || b_type == vllm::kU8B128 || b_type == vllm::kS4 || b_type == vllm::kS8 || b_type == vllm::kFE4M3fn || b_type == vllm::kFE2M1f,
+                           "b_type must be uint4b8, uint8b128, int4, int8, "
+                           "float8_e4m3fn or float4_e2m1f when has_zp = False. Got = ",
+                           b_type.str());
+    }
+
+    if (has_zp && is_zp_float) {
+        if constexpr (!std::is_same<scalar_t, half>::value) {
+            printf("Computation a_type must be float16 (half) when using float zero "
+                   "points.\n");
+        }
+    }
+
+    // Verify b_zeros
+    if (has_zp) {
+        if (is_zp_float) {
+            // b_zeros = [num_groups, size_n]
+            host::RuntimeCheck(b_zeros_size_1 == size_n,
+                               "b_zeros dim 1 = ", b_zeros_size_1,
+                               " is not size_n = ", size_n);
+            host::RuntimeCheck(num_groups != -1, "num_groups must be != -1");
+        } else {
+            host::RuntimeCheck(b_zeros_size_1 == size_n / pack_factor,
+                               "b_zeros dim 1 = ", b_zeros_size_1,
+                               " is not size_n / pack_factor = ", size_n / pack_factor);
+        }
+    }
+
+    // Verify workspace size
+    host::RuntimeCheck(size_n % MARLIN_NAMESPACE_NAME::min_thread_n == 0,
+                       "size_n = ", size_n, ", is not divisible by min_thread_n = ",
+                       MARLIN_NAMESPACE_NAME::min_thread_n);
+
+    // a_scales and global_scale must both point to float data (float *)
+    if (a_type.size_bits() == 16) {
+        host::RuntimeCheck((a_type == c_type), "scalar type of a must match the scalar type of c for 16-bit activation\n");
+    }
+
+    marlin::marlin_mm(
+        a, b_q_weight, c, c_tmp,
+        b_bias, a_scales, b_scales,
+        global_scale, b_zeros, g_idx,
+        perm, a_tmp, size_m, size_n, size_k, a_stride_0,
+        workspace, a_type, b_type, c_type, s_type, has_bias,
+        has_act_order, is_k_full, has_zp, num_groups, group_size, device_id,
+        stream, thread_k, thread_n, sms,
+        use_atomic_add, use_fp32_reduce, is_zp_float);
+    return INFINI_STATUS_SUCCESS;
+}
+
+namespace op::awq_marlin_gemm::nvidia {
+
+struct Descriptor::Opaque {
+    std::shared_ptr<device::nvidia::Handle::Internal> internal;
+};
+
+Descriptor::~Descriptor() { delete _opaque; }
+
+infiniStatus_t Descriptor::create(
+    infiniopHandle_t handle_,
+    Descriptor **desc_ptr,
+    infiniopTensorDescriptor_t out_desc,
+    infiniopTensorDescriptor_t a_desc,
+    infiniopTensorDescriptor_t b_desc,
+    infiniopTensorDescriptor_t b_bias_desc,
+    infiniopTensorDescriptor_t b_scales_desc,
+    infiniopTensorDescriptor_t a_scales_desc,
+    infiniopTensorDescriptor_t global_scales_desc,
+    infiniopTensorDescriptor_t b_zeros_desc,
+    infiniopTensorDescriptor_t g_idx_desc,
+    infiniopTensorDescriptor_t perm_desc) {
+
+    auto handle = reinterpret_cast<device::nvidia::Handle *>(handle_);
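+    // The workspace reported by workspaceSize() is sized for the worst case:
+    // fp32 partial sums (c_tmp), an act-order-permuted copy of A (a_tmp), and
+    // sms * sizeof(int64_t) bytes of lock storage, mirroring the carving done
+    // at launch time in awq_marlin_gemm_kernel.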
+    auto result = AwqMarlinGemmInfo::create(out_desc, a_desc, b_desc, b_bias_desc, b_scales_desc, a_scales_desc,
+                                            global_scales_desc, b_zeros_desc, g_idx_desc, perm_desc);
+    size_t size_m = a_desc->dim(0);
+    size_t size_k = a_desc->dim(1);
+    int a_tmp_bytes = 0;
+    int c_tmp_bytes = 0;
+
+    int device_id = 0;
+    int sms = -1;
+    cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, device_id);
+    int workspace_bytes = sms * sizeof(int64_t);
+    int max_m_block_size = (size_m + 16 - 1) / 16 * 16;
+    max_m_block_size = min(max_m_block_size, 64);
+    int max_c_tmp_size = sms * max_m_block_size * MARLIN_NAMESPACE_NAME::max_thread_n;
+    c_tmp_bytes = max_c_tmp_size * sizeof(float);
+
+    a_tmp_bytes = size_m * size_k * infiniSizeOf(a_desc->dtype());
+
+    size_t workspace_size = c_tmp_bytes + a_tmp_bytes + workspace_bytes;
+
+    *desc_ptr = new Descriptor(
+        new Opaque{handle->internal()},
+        result.take(),
+        workspace_size,
+        handle->device, handle->device_id);
+    return INFINI_STATUS_SUCCESS;
+}
+
+infiniStatus_t
+Descriptor::calculate(
+    void *workspace, size_t workspace_size,
+    void *c,
+    const void *a,
+    const void *b,
+    void *b_bias,
+    void *b_scales,
+    void *a_scales,
+    void *global_scales,
+    void *b_zeros,
+    void *g_idx,
+    void *perm,
+    int64_t b_q_type_id,
+    bool is_k_full,
+    bool use_atomic_add,
+    bool use_fp32_reduce,
+    bool is_zp_float,
+    void *stream_) const {
+
+    cudaStream_t stream = (cudaStream_t)stream_;
+    int size_m = static_cast<int>(_info.size_m);
+    int size_k = static_cast<int>(_info.size_k);
+    int size_n = static_cast<int>(_info.size_n);
+    int b_q_size_0 = static_cast<int>(_info.b_q_size_0);
+    int b_q_size_1 = static_cast<int>(_info.b_q_size_1);
+    int b_zeros_size_1 = static_cast<int>(_info.b_zeros_size_1);
+    int a_stride_0 = static_cast<int>(_info.a_stride_0);
+    int num_groups = _info.num_groups;
+
+#define MARLIN(SCALAR_T, TDATA) \
+    awq_marlin_gemm_kernel<SCALAR_T, TDATA>(a, c, b, b_bias, b_scales, a_scales, global_scales, b_zeros, g_idx, perm, b_q_type_id, is_k_full, use_atomic_add, use_fp32_reduce, is_zp_float, size_m, size_k, size_n, b_q_size_0, b_q_size_1, a_stride_0, b_zeros_size_1, num_groups, workspace, stream)
+
+    if (_info.a_dtype == INFINI_DTYPE_F16) {
+        return MARLIN(half, half);
+    } else if (_info.a_dtype == INFINI_DTYPE_BF16) {
+        return MARLIN(__nv_bfloat16, __nv_bfloat16);
+    } else {
+        return INFINI_STATUS_BAD_TENSOR_DTYPE;
+    }
+}
+
+} // namespace op::awq_marlin_gemm::nvidia
+#endif
diff --git a/src/infiniop/ops/awq_marlin_gemm/nvidia/awq_marlin_gemm_nvidia.cuh b/src/infiniop/ops/awq_marlin_gemm/nvidia/awq_marlin_gemm_nvidia.cuh
new file mode 100644
index 000000000..2b4a20bbe
--- /dev/null
+++ b/src/infiniop/ops/awq_marlin_gemm/nvidia/awq_marlin_gemm_nvidia.cuh
@@ -0,0 +1,8 @@
+#ifndef __AWQ_MARLIN_GEMM_CUDA_CUH__
+#define __AWQ_MARLIN_GEMM_CUDA_CUH__
+
+#include "../awq_marlin_gemm.h"
+
+DESCRIPTOR(nvidia)
+
+#endif // __AWQ_MARLIN_GEMM_CUDA_CUH__
diff --git a/src/infiniop/ops/awq_marlin_gemm/nvidia/generate_kernels.py b/src/infiniop/ops/awq_marlin_gemm/nvidia/generate_kernels.py
new file mode 100644
index 000000000..ab78d0ce2
--- /dev/null
+++ b/src/infiniop/ops/awq_marlin_gemm/nvidia/generate_kernels.py
@@ -0,0 +1,308 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import glob
+import itertools
+import os
+import subprocess
+import sys
+
+import jinja2
+
+ARCHS = []
+SUPPORT_FP8 = False
+SUPPORT_SM75 = False
+SUPPORT_SM80 = False
+for arch in sys.argv[1].split(","):
+    arch = arch[: arch.index(".") + 2].replace(".", "")
+    arch = int(arch)
+    # only SM89 and SM120 fully support
+    # mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32.
+ # SM90 and SM100 can use this PTX, but it’s simulated + # with FP16 MMA, so it cannot achieve any acceleration. + if arch in [89, 120]: + SUPPORT_FP8 = True + if arch >= 80: + SUPPORT_SM80 = True + if arch == 75: + SUPPORT_SM75 = True + +FILE_HEAD_COMMENT = """ +// auto generated by generate_kernels.py +// clang-format off +""".lstrip() + + +FILE_HEAD = ( + FILE_HEAD_COMMENT + + """ +#include "../marlin/kernel.h" +#include "../marlin/marlin_template.h" + +namespace MARLIN_NAMESPACE_NAME { +""" +) + +TEMPLATE = ( + "template __global__ void Marlin<" + "{{a_type_id}}, " + "{{b_type_id}}, " + "{{c_type_id}}, " + "{{s_type_id}}, " + "{{threads}}, " + "{{thread_m_blocks}}, " + "{{thread_n_blocks}}, " + "{{thread_k_blocks}}, " + "{{m_block_size_8}}, " + "{{stages}}, " + "{{group_blocks}}, " + "{{is_zp_float}}>" + "( MARLIN_KERNEL_PARAMS );" +) + +THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128), (128, 64, 128)] + +THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4] + +QUANT_CONFIGS = [ + # AWQ-INT4 + { + "b_type": "kU4", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": THREAD_M_BLOCKS, + "group_blocks": [-1, 2, 4, 8], + }, + # GPTQ-INT4 + { + "b_type": "kU4B8", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": THREAD_M_BLOCKS, + "group_blocks": [-1, 0, 2, 4, 8], + }, + # GPTQ-INT8 + { + "b_type": "kU8B128", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": THREAD_M_BLOCKS, + "group_blocks": [-1, 0, 2, 4, 8], + }, + # FP8 + { + "b_type": "kFE4M3fn", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": THREAD_M_BLOCKS, + "group_blocks": [-1, 8], + }, + # NVFP4 + { + "b_type": "kFE2M1f", + "s_type": "kFE4M3fn", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": THREAD_M_BLOCKS, + "group_blocks": [1], + }, + # MXFP4 + { + "a_type": ["kBFloat16"], + "b_type": "kFE2M1f", + "s_type": "kFE8M0fnu", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": THREAD_M_BLOCKS, + "group_blocks": [2], + }, + # AWQ-INT4 with INT8 activation + { + "a_type": ["kS8"], + "b_type": "kU4", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": [1, 2, 3, 4], + "group_blocks": [-1, 2, 4, 8], + }, + # GPTQ-INT4 with INT8 activation + { + "a_type": ["kS8"], + "b_type": "kU4B8", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": [1, 2, 3, 4], + "group_blocks": [-1, 2, 4, 8], + }, + # GPTQ-INT4 with FP8 activation + { + "a_type": ["kFE4M3fn"], + "b_type": "kU4B8", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": [1, 2, 3, 4], + "group_blocks": [-1, 2, 4, 8], + }, + # AWQ-INT4 with FP8 activation + { + "a_type": ["kFE4M3fn"], + "b_type": "kU4", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": [1, 2, 3, 4], + "group_blocks": [-1, 2, 4, 8], + }, + # MXFP4 with FP8 activation + { + "a_type": ["kFE4M3fn"], + "b_type": "kFE2M1f", + "c_type": ["kBFloat16"], + "s_type": "kFE8M0fnu", + "thread_configs": THREAD_CONFIGS, + "thread_m_blocks": [1, 2, 3, 4], + "group_blocks": [2], + }, +] + + +def remove_old_kernels(): + for filename in glob.glob(os.path.dirname(__file__) + "/*kernel_*.cu"): + subprocess.call(["rm", "-f", filename]) + + filename = os.path.dirname(__file__) + "/kernel_selector.h" + subprocess.call(["rm", "-f", filename]) + + +def generate_new_kernels(): + result_dict = {} + sm_75_result_dict = {} + + for quant_config in QUANT_CONFIGS: + c_types = quant_config.get("c_type", ["kFloat16", "kBFloat16"]) + a_types = quant_config.get("a_type", ["kFloat16", "kBFloat16"]) + b_type = quant_config["b_type"] + is_zp_float = quant_config.get("is_zp_float", False) + 
all_group_blocks = quant_config["group_blocks"] + all_m_blocks = quant_config["thread_m_blocks"] + all_thread_configs = quant_config["thread_configs"] + + for a_type, c_type in itertools.product(a_types, c_types): + if not SUPPORT_FP8 and a_type == "kFE4M3fn": + continue + if "16" in a_type and "16" in c_type and a_type != c_type: + continue + s_type = quant_config.get("s_type", c_type) + if (a_type, b_type, c_type) not in result_dict: + result_dict[(a_type, b_type, c_type)] = [] + if a_type in ["kFloat16", "kS8"] and c_type == "kFloat16": + sm_75_result_dict[(a_type, b_type, c_type)] = [] + + for group_blocks, m_blocks, thread_configs in itertools.product( + all_group_blocks, all_m_blocks, all_thread_configs + ): + thread_k, thread_n, threads = thread_configs + + if threads == 256: + # for small batch (m_blocks == 1), + # we only need (128, 128, 256) + # for large batch (m_blocks > 1), + # we only need (64, 256, 256) + if m_blocks <= 1 and (thread_k, thread_n) != (128, 128): + continue + if m_blocks > 1 and (thread_k, thread_n) != (64, 256): + continue + + config = { + "threads": threads, + "s_type": s_type, + "thread_m_blocks": max(m_blocks, 1), + "thread_k_blocks": thread_k // 16, + "thread_n_blocks": thread_n // 16, + "m_block_size_8": "true" if m_blocks == 0.5 else "false", + "stages": 4, + "group_blocks": group_blocks, + "is_zp_float": "true" if is_zp_float else "false", + } + + if SUPPORT_SM80: + result_dict[(a_type, b_type, c_type)].append(config) + if (a_type, b_type, c_type) in sm_75_result_dict and SUPPORT_SM75: + config_sm75 = config.copy() + config_sm75["stages"] = 2 + sm_75_result_dict[(a_type, b_type, c_type)].append(config_sm75) + + kernel_selector_str = FILE_HEAD_COMMENT + + for result_dict_tmp in [result_dict, sm_75_result_dict]: + for (a_type, b_type, c_type), config_list in result_dict_tmp.items(): + all_template_str_list = [] + if not config_list: + continue + for config in config_list: + s_type = config["s_type"] + template_str = jinja2.Template(TEMPLATE).render( + a_type_id=f"vllm::{a_type}.id()", + b_type_id=f"vllm::{b_type}.id()", + c_type_id=f"vllm::{c_type}.id()", + s_type_id=f"vllm::{s_type}.id()", + **config, + ) + all_template_str_list.append(template_str) + + conditions = [ + f"a_type == vllm::{a_type}", + f"b_type == vllm::{b_type}", + f"c_type == vllm::{c_type}", + f"s_type == vllm::{s_type}", + f"threads == {config['threads']}", + f"thread_m_blocks == {config['thread_m_blocks']}", + f"thread_n_blocks == {config['thread_n_blocks']}", + f"thread_k_blocks == {config['thread_k_blocks']}", + f"m_block_size_8 == {config['m_block_size_8']}", + f"stages == {config['stages']}", + f"group_blocks == {config['group_blocks']}", + f"is_zp_float == {config['is_zp_float']}", + ] + conditions = " && ".join(conditions) + + if kernel_selector_str == FILE_HEAD_COMMENT: + kernel_selector_str += f"if ({conditions})\n kernel = " + else: + kernel_selector_str += f"else if ({conditions})\n kernel = " + + kernel_template2 = ( + "Marlin<{{a_type_id}}, {{b_type_id}}, {{c_type_id}}, " + "{{s_type_id}}, {{threads}}, {{thread_m_blocks}}, " + "{{thread_n_blocks}}, {{thread_k_blocks}}, " + "{{m_block_size_8}}, {{stages}}, {{group_blocks}}, " + "{{is_zp_float}}>;" + ) + + kernel_selector_str += ( + jinja2.Template(kernel_template2).render( + a_type_id=f"vllm::{a_type}.id()", + b_type_id=f"vllm::{b_type}.id()", + c_type_id=f"vllm::{c_type}.id()", + s_type_id=f"vllm::{s_type}.id()", + **config, + ) + + "\n" + ) + + file_content = FILE_HEAD + "\n\n" + file_content += 
"\n\n".join(all_template_str_list) + "\n\n}\n" + if a_type == "kFE4M3fn": + filename = f"sm89_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu" + elif result_dict_tmp is sm_75_result_dict: + filename = f"sm75_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu" + else: + filename = f"sm80_kernel_{a_type[1:]}_{b_type[1:]}_{c_type[1:]}.cu" + + filename = filename.lower() + + with open(os.path.join(os.path.dirname(__file__), filename), "w") as f: + f.write(file_content) + + if not SUPPORT_FP8 and kernel_selector_str != FILE_HEAD_COMMENT: + kernel_selector_str += ( + "else if (a_type == vllm::kFE4M3fn)\n" + " host::RuntimeCheck(false, " + '"marlin kernel with fp8 activation is not built.");' + ) + + with open(os.path.join(os.path.dirname(__file__), "kernel_selector.h"), "w") as f: + f.write(kernel_selector_str) + + +if __name__ == "__main__": + remove_old_kernels() + generate_new_kernels() diff --git a/src/infiniop/ops/awq_marlin_gemm/nvidia/kernel.cuh b/src/infiniop/ops/awq_marlin_gemm/nvidia/kernel.cuh new file mode 100644 index 000000000..be4081be7 --- /dev/null +++ b/src/infiniop/ops/awq_marlin_gemm/nvidia/kernel.cuh @@ -0,0 +1,542 @@ +/* + * Modified by Neural Magic + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Adapted from https://github.com/IST-DASLab/marlin + */ + +#ifndef MARLIN_NAMESPACE_NAME +#define MARLIN_NAMESPACE_NAME marlin +#endif + +#include "../core/utils.h" +#include "../marlin/kernel.h" + +#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ + static_assert(std::is_same::value || std::is_same::value, \ + "only float16 and bfloat16 is supported"); + +namespace marlin { + +__global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){}; + +using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS); + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 + +__global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr, + int const *__restrict__ perm_int_ptr, + int4 *__restrict__ out_int4_ptr, int size_m, + int size_k, int lda, int block_rows) {} + +} // namespace marlin + +template +void awq_marlin_gemm_kernel( + const void *a, + void *c, + const void *b_q_weight, + void *b_bias, + void *b_scales, + void *a_scales, + void *global_scale, + void *b_zeros, + void *g_idx, + void *perm, + int64_t b_type_id, + bool is_k_full, + bool use_atomic_add, + bool use_fp32_reduce, + bool is_zp_float, + int size_m, + int size_k, + int size_n, + int b_q_size_0, + int b_q_size_1, + int a_stride_0, + int b_zeros_size_1, + int num_groups, + cudaStream_t stream) { +} + +#else + +// For a given "a" of size [M,K] performs a permutation of the K columns based +// on the given "perm" indices. 
+__global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr, + int const *__restrict__ perm_int_ptr, + int4 *__restrict__ out_int4_ptr, int size_m, + int size_k, int lda, int block_rows) { + auto start_row = block_rows * blockIdx.x; + int finish_row = start_row + block_rows; + if (finish_row > size_m) { + finish_row = size_m; + } + int cur_block_rows = finish_row - start_row; + + int input_row_stride = lda * sizeof(half) / 16; + int output_row_stride = size_k * sizeof(half) / 16; + + auto permute_row = [&](int row) { + int iters = size_k / default_threads; + int rest = size_k % default_threads; + + int input_offset = row * input_row_stride; + int output_offset = row * output_row_stride; + + half const *a_row_half = reinterpret_cast(a_int4_ptr + input_offset); + half *out_half = reinterpret_cast(out_int4_ptr + output_offset); + + int base_k = 0; + + for (int i = 0; i < iters; i++) { + auto cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + + base_k += default_threads; + } + + if (rest) { + if (threadIdx.x < rest) { + auto cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + } + } + }; + + for (int i = 0; i < cur_block_rows; i++) { + int cur_row = start_row + i; + if (cur_row < size_m) { + permute_row(cur_row); + } + } +} + +typedef struct +{ + int thread_k; + int thread_n; + int num_threads; +} thread_config_t; + +thread_config_t small_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {128, 128, 256}, + {64, 128, 128}, + {128, 64, 128}}; + +thread_config_t large_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {64, 256, 256}, + {64, 128, 128}, + {128, 64, 128}}; + +typedef struct +{ + int blocks_per_sm; + thread_config_t tb_cfg; +} exec_config_t; + +int get_scales_cache_size(thread_config_t const &th_config, int prob_m, + int prob_n, int prob_k, int num_bits, int group_size, + bool has_act_order, bool is_k_full, int stages) { + bool cache_scales_chunk = has_act_order && !is_k_full; + + int tb_n = th_config.thread_n; + int tb_k = th_config.thread_k; + + // Get max scale groups per thread-block + int tb_groups; + if (group_size == -1) { + tb_groups = 1; + } else if (group_size == 0) { + tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size + } else { + tb_groups = div_ceil(tb_k, group_size); + } + + if (cache_scales_chunk) { + int load_groups = tb_groups * stages * 2; // Chunk size is 2x pipeline over dim K + load_groups = max(load_groups, 32); // We load at least 32 scale groups + return load_groups * tb_n * 2; + } else { + int tb_scales = tb_groups * tb_n * 2; + + return tb_scales * stages; + } +} + +int get_kernel_cache_size(thread_config_t const &th_config, int thread_m_blocks, + int prob_m, int prob_n, int prob_k, int num_bits, + int group_size, bool has_act_order, bool is_k_full, + int has_zp, bool is_zp_float, bool is_a_8bit, + int stages) { + int pack_factor = 32 / num_bits; + + // Get B size + int tb_k = th_config.thread_k; + int tb_n = th_config.thread_n; + int tb_m = thread_m_blocks * 16; + int sh_a_size = stages * (tb_m * tb_k) * (is_a_8bit ? 1 : 2); + int sh_b_size = stages * (tb_k * tb_n / pack_factor) * 4; + int sh_red_size = tb_m * (tb_n + 8) * 2; + int sh_bias_size = tb_n * 2; + int tmp_size = (sh_b_size > sh_red_size ? 
sh_red_size : sh_b_size) + sh_bias_size; + tmp_size = max(max(sh_b_size, sh_red_size), tmp_size); + + int sh_s_size = get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, + group_size, has_act_order, is_k_full, stages); + int sh_g_idx_size = has_act_order && !is_k_full ? stages * tb_k / 4 : 0; + int sh_zp_size = 0; + if (has_zp) { + if (is_zp_float) { + sh_zp_size = sh_s_size; + } else if (num_bits == 4) { + sh_zp_size = sh_s_size / 4; + } else if (num_bits == 8) { + sh_zp_size = sh_s_size / 2; + } + } + + int total_size = tmp_size + sh_a_size + sh_s_size + sh_zp_size + sh_g_idx_size; + + return total_size; +} + +bool is_valid_config(thread_config_t const &th_config, int thread_m_blocks, + int prob_m, int prob_n, int prob_k, int num_bits, + int group_size, bool has_act_order, bool is_k_full, + int has_zp, bool is_zp_float, bool is_a_8bit, int stages, + int max_shared_mem) { + // Sanity + if (th_config.thread_k == -1 || th_config.thread_n == -1 || th_config.num_threads == -1) { + return false; + } + + // Verify K/N are divisible by thread K/N + if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { + return false; + } + + // Verify min for thread K/N + if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { + return false; + } + + // num_threads must be at least 128 (= 4 warps) + if (th_config.num_threads < 128) { + return false; + } + + // Check that pipeline fits into cache + int cache_size = get_kernel_cache_size( + th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size, + has_act_order, is_k_full, has_zp, is_zp_float, is_a_8bit, stages); + return cache_size <= max_shared_mem; +} + +MarlinFuncPtr get_marlin_kernel( + const vllm::ScalarType a_type, const vllm::ScalarType b_type, + const vllm::ScalarType c_type, const vllm::ScalarType s_type, + int thread_m_blocks, int thread_n_blocks, int thread_k_blocks, + bool m_block_size_8, bool has_act_order, bool has_zp, int group_blocks, + int threads, bool is_zp_float, int stages) { + int num_bits = b_type.size_bits(); + auto kernel = MarlinDefault; + +#include "kernel_selector.h" + + return kernel; +} + +exec_config_t determine_exec_config( + const vllm::ScalarType &a_type, const vllm::ScalarType &b_type, + const vllm::ScalarType &c_type, const vllm::ScalarType &s_type, int prob_m, + int prob_n, int prob_k, int thread_m_blocks, bool m_block_size_8, + int num_bits, int group_size, bool has_act_order, bool is_k_full, + bool has_zp, bool is_zp_float, int is_a_8bit, int stages, + int max_shared_mem, int sms) { + exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}}; + thread_config_t *thread_configs = thread_m_blocks > 1 + ? large_batch_thread_configs + : small_batch_thread_configs; + int thread_configs_size = thread_m_blocks > 1 + ? sizeof(large_batch_thread_configs) / sizeof(thread_config_t) + : sizeof(small_batch_thread_configs) / sizeof(thread_config_t); + + for (int i = 0; i < thread_configs_size; i++) { + thread_config_t th_config = thread_configs[i]; + + if (!is_valid_config(th_config, thread_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, has_act_order, is_k_full, has_zp, + is_zp_float, is_a_8bit, stages, + max_shared_mem - 512)) { + continue; + } + + int cache_size = get_kernel_cache_size(th_config, thread_m_blocks, prob_m, + prob_n, prob_k, num_bits, group_size, + has_act_order, is_k_full, has_zp, + is_zp_float, is_a_8bit, stages); + + int group_blocks = 0; + if (!has_act_order) { + group_blocks = group_size == -1 ? 
-1 : group_size / 16;
+        }
+
+        auto kernel = get_marlin_kernel(a_type, b_type, c_type, s_type, thread_m_blocks,
+                                        th_config.thread_n / 16, th_config.thread_k / 16,
+                                        m_block_size_8, has_act_order, has_zp, group_blocks,
+                                        th_config.num_threads, is_zp_float, stages);
+
+        if (kernel == MarlinDefault) {
+            continue;
+        }
+
+        return {1, th_config};
+    }
+
+    return exec_cfg;
+}
+
+void marlin_mm(const void *A, const void *B, void *C, void *C_tmp, void *b_bias,
+               void *a_s, void *b_s, void *g_s, void *zp, void *g_idx,
+               void *perm, void *a_tmp, int prob_m, int prob_n, int prob_k,
+               int lda, void *workspace, vllm::ScalarType const &a_type,
+               vllm::ScalarType const &b_type, vllm::ScalarType const &c_type,
+               vllm::ScalarType const &s_type, bool has_bias,
+               bool has_act_order, bool is_k_full, bool has_zp, int num_groups,
+               int group_size, int dev, cudaStream_t stream, int thread_k_init,
+               int thread_n_init, int sms, bool use_atomic_add,
+               bool use_fp32_reduce, bool is_zp_float) {
+    bool is_a_8bit = a_type.size_bits() == 8;
+    host::RuntimeCheck(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
+                       ", ", prob_n, ", ", prob_k, "]");
+
+    int group_blocks = 0;
+    if (has_act_order) {
+        if (is_k_full) {
+            host::RuntimeCheck(group_size != -1);
+            group_blocks = group_size / 16;
+            host::RuntimeCheck(prob_k % group_blocks == 0, "prob_k = ", prob_k,
+                               " is not divisible by group_blocks = ", group_blocks);
+        } else {
+            host::RuntimeCheck(group_size == 0);
+            group_blocks = 0;
+        }
+    } else {
+        if (group_size == -1) {
+            group_blocks = -1;
+        } else {
+            group_blocks = group_size / 16;
+            host::RuntimeCheck(prob_k % group_blocks == 0, "prob_k = ", prob_k,
+                               " is not divisible by group_blocks = ", group_blocks);
+        }
+    }
+
+    int num_bits = b_type.size_bits();
+    const int4 *A_ptr = (const int4 *)A;
+    const int4 *B_ptr = (const int4 *)B;
+    int4 *C_ptr = (int4 *)C;
+    int4 *C_tmp_ptr = (int4 *)C_tmp;
+
+    const int4 *bias_ptr = (const int4 *)b_bias;
+    const float *a_s_ptr = (const float *)a_s;
+    const int4 *b_s_ptr = (const int4 *)b_s;
+    const float *g_s_ptr = (const float *)g_s;
+
+    const int4 *zp_ptr = (const int4 *)zp;
+    const int *g_idx_ptr = (const int *)g_idx;
+    const int *perm_ptr = (const int *)perm;
+    int4 *a_tmp_ptr = (int4 *)a_tmp;
+    int *locks = (int *)workspace;
+
+    if (has_act_order) {
+        // Permute A columns
+        int block_rows = div_ceil(prob_m, sms);
+        // avoid ">>>" being formatted to "> > >"
+        // clang-format off
+        permute_cols_kernel<<<sms, default_threads, 0, stream>>>(
+            A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, lda, block_rows);
+        // clang-format on
+        A_ptr = a_tmp_ptr;
+        lda = prob_k;
+
+        // If we have a full K, then we can run the non-act-order version of Marlin
+        // (since the weight rows are reordered by increasing group ids, and by
+        // having a full K, we have full original groups)
+        if (is_k_full) {
+            has_act_order = false;
+        }
+    }
+
+    int max_shared_mem = 0;
+    cudaDeviceGetAttribute(&max_shared_mem,
+                           cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
+    host::RuntimeCheck(max_shared_mem > 0);
+
+    int major_capability, minor_capability;
+    cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
+                           dev);
+    cudaDeviceGetAttribute(&minor_capability, cudaDevAttrComputeCapabilityMinor,
+                           dev);
+    host::RuntimeCheck(major_capability * 10 + minor_capability >= 75,
+                       "marlin kernel only supports Turing or newer GPUs.");
+    int stages = 4;
+    if (major_capability == 7 && minor_capability == 5) {
+        stages = 2;
+        host::RuntimeCheck(a_type == vllm::kFloat16 || a_type == vllm::kS8,
+                           "Turing only supports FP16 or INT8 activation.");
+    }
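+    // Below, prob_m is split into chunks (`prob_m_split`) of at most
+    // max_thread_m_blocks * 16 * max_par rows and one grid is launched per
+    // chunk, so arbitrarily tall problems reuse the same tile configurations.
+    // Rough shape of the loop (sketch):
+    //   while (rest_m) { pick thread config; launch kernel; advance A/C; }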
activation."); + } + if (a_type == vllm::kFE4M3fn) { + host::RuntimeCheck( + major_capability * 10 + minor_capability == 89 || major_capability * 10 + minor_capability == 120, + "Marlin W4A8-FP8 only support SM89 or SM120 device (It is slower than " + "Marlin W4A16 on other devices)."); + } + + int max_par = 16; + if (prob_n <= 4096) { + max_par = 16 * 8; + } + int max_shared_mem_new = max_shared_mem; + int rest_m = prob_m; + int max_thread_m_blocks = 4; + while (rest_m) { + int par_count = rest_m / (max_thread_m_blocks * 16); + if (par_count > max_par) { + par_count = max_par; + } + int prob_m_split = par_count > 0 ? (par_count * (max_thread_m_blocks * 16)) : rest_m; + + int thread_k = thread_k_init; + int thread_n = thread_n_init; + + int thread_m_blocks = min(div_ceil(prob_m_split, 16), max_thread_m_blocks); + int m_block_size_8 = prob_m_split <= 8 && a_type.size_bits() == 16; + + // Set thread config + exec_config_t exec_cfg; + thread_config_t thread_tfg; + if (thread_k != -1 && thread_n != -1) { + thread_tfg = thread_config_t{thread_k, thread_n, default_threads}; + exec_cfg = exec_config_t{1, thread_tfg}; + host::RuntimeCheck(prob_n % thread_n == 0, "prob_n = ", prob_n, + " is not divisible by thread_n = ", thread_n); + host::RuntimeCheck(prob_k % thread_k == 0, "prob_k = ", prob_k, + " is not divisible by thread_k = ", thread_k); + } else { + // Auto config + exec_cfg = determine_exec_config( + a_type, b_type, c_type, s_type, prob_m_split, prob_n, prob_k, + thread_m_blocks, m_block_size_8, num_bits, group_size, has_act_order, + is_k_full, has_zp, is_zp_float, is_a_8bit, stages, max_shared_mem, + sms); + thread_tfg = exec_cfg.tb_cfg; + if (thread_tfg.thread_n != -1) { + if (prob_n / thread_tfg.thread_n * div_ceil(prob_m_split, thread_m_blocks * 16) * 4 <= sms) { + if (is_valid_config({128, 64, 128}, thread_m_blocks, prob_m_split, + prob_n, prob_k, num_bits, group_size, + has_act_order, is_k_full, has_zp, is_zp_float, + is_a_8bit, stages, max_shared_mem_new)) { + thread_tfg = {128, 64, 128}; + exec_cfg = {1, thread_tfg}; + } + } + } + + if (thread_tfg.thread_k == -1 && max_thread_m_blocks > 1) { + max_thread_m_blocks--; + continue; + } + } + + int num_threads = thread_tfg.num_threads; + thread_k = thread_tfg.thread_k; + thread_n = thread_tfg.thread_n; + int blocks = sms * exec_cfg.blocks_per_sm; + if (exec_cfg.blocks_per_sm > 1) { + max_shared_mem_new = max_shared_mem / exec_cfg.blocks_per_sm - 1024; + } + + int thread_k_blocks = thread_k / 16; + int thread_n_blocks = thread_n / 16; + + host::RuntimeCheck( + is_valid_config(thread_tfg, thread_m_blocks, prob_m_split, prob_n, + prob_k, num_bits, group_size, has_act_order, is_k_full, + has_zp, is_zp_float, is_a_8bit, stages, + max_shared_mem_new), + "Invalid thread config: thread_m_blocks = ", thread_m_blocks, + ", thread_k = ", thread_tfg.thread_k, + ", thread_n = ", thread_tfg.thread_n, + ", num_threads = ", thread_tfg.num_threads, " for MKN = [", prob_m, + ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, + ", prob_m_split = ", prob_m_split, ", group_size = ", group_size, + ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full, + ", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float, + ", stages = ", stages, ", max_shared_mem_new = ", max_shared_mem_new); + + auto kernel = get_marlin_kernel( + a_type, b_type, c_type, s_type, thread_m_blocks, thread_n_blocks, + thread_k_blocks, m_block_size_8, has_act_order, has_zp, group_blocks, + num_threads, is_zp_float, stages); + + if (kernel == MarlinDefault) { + 
host::RuntimeCheck(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n, + ", ", prob_k, "]", ", has_act_order = ", has_act_order, + ", num_groups = ", num_groups, ", group_size = ", group_size, + ", prob_m_split = ", prob_m_split, + ", thread_m_blocks = ", thread_m_blocks, + ", thread_n_blocks = ", thread_n_blocks, + ", thread_k_blocks = ", thread_k_blocks, + ", num_threads = ", num_threads, ", num_bits = ", num_bits); + } + + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + max_shared_mem_new); + + bool part_use_atomic_add = use_atomic_add && div_ceil(prob_m_split, 64) * prob_n <= 2048; + + // avoid ">>>" being formatted to "> > >" + // clang-format off + kernel<<>>( + A_ptr, B_ptr, C_ptr, C_tmp_ptr, bias_ptr, a_s_ptr, b_s_ptr, g_s_ptr, zp_ptr, + g_idx_ptr, num_groups, + prob_m_split, prob_n, prob_k, lda, locks, has_bias, part_use_atomic_add, + use_fp32_reduce, max_shared_mem_new); + // clang-format on + + bool is_a_8bit = a_type.size_bits() == 8; + A_ptr += prob_m_split * (lda / (is_a_8bit ? 16 : 8)); + a_s_ptr += prob_m_split; + C_ptr += prob_m_split * (prob_n / 8); + rest_m -= prob_m_split; + } +} + +} // namespace marlin + +#endif diff --git a/src/infiniop/ops/awq_marlin_gemm/operator.cc b/src/infiniop/ops/awq_marlin_gemm/operator.cc new file mode 100644 index 000000000..7ef8246b1 --- /dev/null +++ b/src/infiniop/ops/awq_marlin_gemm/operator.cc @@ -0,0 +1,126 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/awq_marlin_gemm.h" + +#if defined ENABLE_NVIDIA_API +#include "nvidia/awq_marlin_gemm_nvidia.cuh" +#endif + +__INFINI_C infiniStatus_t infiniopCreateAwqMarlinGemmDescriptor( + infiniopHandle_t handle, + infiniopAwqMarlinGemmDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t b_bias_desc, + infiniopTensorDescriptor_t b_scales_desc, + infiniopTensorDescriptor_t a_scales_desc, + infiniopTensorDescriptor_t global_scales_desc, + infiniopTensorDescriptor_t b_zeros_desc, + infiniopTensorDescriptor_t g_idx_desc, + infiniopTensorDescriptor_t perm_desc) { +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::awq_marlin_gemm::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + out_desc, \ + a_desc, \ + b_desc, \ + b_bias_desc, \ + b_scales_desc, \ + a_scales_desc, \ + global_scales_desc, \ + b_zeros_desc, \ + g_idx_desc, \ + perm_desc) + + switch (handle->device) { +#ifdef ENABLE_NVIDIA_API + CREATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CREATE +} + +__INFINI_C infiniStatus_t infiniopGetAwqMarlinGemmWorkspaceSize(infiniopAwqMarlinGemmDescriptor_t desc, + size_t *size) { +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspaceSize(); \ + return INFINI_STATUS_SUCCESS + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + GET(INFINI_DEVICE_NVIDIA, nvidia); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef GET +} + +__INFINI_C infiniStatus_t infiniopAwqMarlinGemm( + infiniopAwqMarlinGemmDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *c, + const void *a, + const void *b, + void *b_bias, + void *b_scales, + void *a_scales, + void *global_scales, + void *b_zeros, + void *g_idx, + void *perm, + int64_t b_q_type_id, + bool is_k_full, + bool use_atomic_add, + bool use_fp32_reduce, + bool is_zp_float, + void *stream) { 
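+    // Typical call sequence against this dispatcher (sketch; error handling
+    // omitted and the buffer names are placeholders):
+    //   infiniopAwqMarlinGemmDescriptor_t desc;
+    //   infiniopCreateAwqMarlinGemmDescriptor(handle, &desc, out, a, b, bias,
+    //                                         b_scales, a_scales, g_scales,
+    //                                         b_zeros, g_idx, perm);
+    //   size_t ws_size = 0;
+    //   infiniopGetAwqMarlinGemmWorkspaceSize(desc, &ws_size);
+    //   infiniopAwqMarlinGemm(desc, workspace, ws_size, c, a, b, bias, b_scales,
+    //                         a_scales, g_scales, b_zeros, g_idx, perm,
+    //                         b_q_type_id, is_k_full, /*use_atomic_add=*/false,
+    //                         /*use_fp32_reduce=*/true, /*is_zp_float=*/false,
+    //                         stream);
+    //   infiniopDestroyAwqMarlinGemmDescriptor(desc);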
+ +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc) \ + ->calculate(workspace, workspace_size, c, a, b, b_bias, b_scales, a_scales, global_scales, b_zeros, g_idx, perm, b_q_type_id, is_k_full, use_atomic_add, use_fp32_reduce, is_zp_float, stream) + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CALCULATE +} + +__INFINI_C infiniStatus_t +infiniopDestroyAwqMarlinGemmDescriptor(infiniopAwqMarlinGemmDescriptor_t desc) { + +#define DELETE(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + DELETE(INFINI_DEVICE_NVIDIA, nvidia); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef DELETE +} + +// #endif diff --git a/test/infiniop/awq_marlin_gemm.py b/test/infiniop/awq_marlin_gemm.py new file mode 100644 index 000000000..059166b24 --- /dev/null +++ b/test/infiniop/awq_marlin_gemm.py @@ -0,0 +1,837 @@ +import torch +import ctypes +from ctypes import c_uint64 +from libinfiniop import ( + LIBINFINIOP, + TestTensor, + TestWorkspace, + get_test_devices, + check_error, + test_operator, + get_args, + debug, + get_tolerance, + profile_operation, + InfiniDtype, + InfiniDtypeNames, + InfiniDeviceNames, + infiniopOperatorDescriptor_t, + to_torch_dtype, +) +import itertools +from libinfiniop.scalar_type import scalar_types, ScalarType +from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union +import numpy as np + + +_TEST_CASES_SUBSET_INPUT = [ + # (size_m, size_k, size_n, group_size, quant_type) + (32, 1024, 2048, 128, scalar_types.uint4b8), +] + +_TEST_CASES_WITH_BIAS = [ + # (size_m, size_k, size_n, group_size, quant_type) + (1, 1024, 2048, 128, scalar_types.uint4b8), + (256, 1024, 2048, 128, scalar_types.uint4b8), +] + +_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16] + +DEBUG = False +PROFILE = False +NUM_PRERUN = 10 +NUM_ITERATIONS = 1000 + + +def rand_data(shape, dtype, device): + return torch.randn(shape, dtype=dtype, device=device) + + +def get_scale_perms(): + scale_perm: list[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: list[int] = [] + for i in range(4): + scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + +def marlin_permute_bias(s: torch.Tensor) -> torch.Tensor: + origin_shape = s.shape + _, scale_perm_single = get_scale_perms() + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + return s.reshape(*origin_shape).contiguous() + + +def quantize_weights( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int | None, + zero_points: bool = False, + ref_zero_points_after_scales: bool = False, +): + assert ( + quant_type.is_integer() + ), "Floating point quantization may work but has not been tested" + assert not zero_points or group_size is not None, ( + "to have group zero points, group_size must be provided " + "(-1 group_size is channelwise)" + ) + + orig_device = w.device + orig_type = w.dtype + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + + if group_size == -1: + group_size = size_k + + # Reshape to [groupsize, -1] + if group_size is not None and group_size < size_k: + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + # Compute 
+    max_val = torch.max(w, 0, keepdim=True).values
+    min_val = torch.min(w, 0, keepdim=True).values
+
+    max_q_val = quant_type.max()
+    min_q_val = quant_type.min()
+
+    w_s = torch.Tensor([1.0]).to(w.device)  # unscaled case
+    maybe_w_zp = None
+    if group_size is not None:
+        if zero_points:
+            assert not quant_type.is_signed() and quant_type.max() > 0
+            w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
+            maybe_w_zp = (
+                torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
+            )
+        else:
+            # If the bias is such that there are no possible negative/positive
+            # values, set the max value to inf to avoid divide by 0
+            w_s = torch.max(
+                abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
+                abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
+            )
+
+    # Quantize
+    w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
+    w_q = torch.clamp(w_q, min_q_val, max_q_val)
+
+    # Compute ref (dequantized)
+    # For some kernels (namely Machete) the zero-points are applied after the
+    # scales are applied, for this case computing the reference in similar way
+    # allows us to use tighter error tolerances in our unit tests.
+    if ref_zero_points_after_scales and maybe_w_zp is not None:
+        w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
+    else:
+        w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s
+
+    if quant_type.has_bias():
+        w_q += quant_type.bias
+
+    # Restore original shapes
+    if group_size is not None and group_size < size_k:
+
+        def reshape_w(w):
+            w = w.reshape((group_size, -1, size_n))
+            w = w.permute(1, 0, 2)
+            w = w.reshape((size_k, size_n)).contiguous()
+            return w
+
+        w_q = reshape_w(w_q)
+        w_ref = reshape_w(w_ref)
+        w_s = w_s.reshape((-1, size_n)).contiguous()
+
+    if maybe_w_zp is not None:
+        maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
+        maybe_w_zp = maybe_w_zp.to(device=orig_device)
+
+    return (
+        w_ref.to(device=orig_device),
+        w_q.to(device=orig_device),
+        w_s if group_size is not None else None,
+        maybe_w_zp,
+    )
+
+
+SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
+SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
+
+
+def permute_rows(
+    q_w: torch.Tensor,
+    w_ref: torch.Tensor,
+    group_size: int,
+    test_perm: torch.Tensor | None = None,
+):
+    assert q_w.shape == w_ref.shape
+
+    orig_device = q_w.device
+    k_size, _ = q_w.shape
+
+    g_idx = torch.zeros((k_size,), dtype=torch.int32)
+    for i in range(k_size):
+        g_idx[i] = i // group_size
+
+    # Simulate act_order by doing a random permutation on K
+    rand_perm = test_perm if test_perm is not None else torch.randperm(k_size)
+
+    g_idx = g_idx[rand_perm].contiguous()
+    q_w = q_w[rand_perm, :].contiguous()
+    w_ref = w_ref[rand_perm, :].contiguous()
+
+    return (
+        w_ref.to(device=orig_device),
+        q_w.to(device=orig_device),
+        g_idx.to(device=orig_device),
+        rand_perm.to(device=orig_device),
+    )
+
+
+def gptq_quantize_weights(
+    w: torch.Tensor,
+    quant_type: ScalarType,
+    group_size: int,
+    act_order: bool,
+    test_perm: torch.Tensor | None = None,
+):
+    size_k, _ = w.shape
+
+    assert w.is_floating_point(), "w must be float"
+    assert (
+        quant_type in SUPPORTED_GPTQ_QUANT_TYPES
+    ), f"Unsupported gptq type = {quant_type}"
+    assert group_size in SUPPORTED_GROUP_SIZES + [
+        size_k
+    ], f"Unsupported groupsize = {group_size}"
+
+    w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size)
+
+    # Apply act_order
+    g_idx = torch.empty(0, dtype=torch.int, device=w.device)
+    rand_perm = torch.empty(0, dtype=torch.int, device=w.device)
+    if act_order:
+        assert (
+            group_size < size_k
+        ), "For act_order, groupsize = {} must be less than size_k = {}".format(
+            group_size, size_k
+        )
+
+        w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, test_perm)
+
+    return w_ref, w_q, w_s, g_idx, rand_perm
+
+
+def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor):
+    orig_device = q_w.device
+
+    sort_indices = torch.argsort(g_idx).to(dtype=torch.int32)  # Sort based on g_idx
+
+    g_idx = g_idx[sort_indices].contiguous()
+    q_w = q_w[sort_indices, :].contiguous()
+
+    return (
+        q_w.to(device=orig_device),
+        g_idx.to(device=orig_device),
+        sort_indices.to(device=orig_device),
+    )
+
+
+def get_weight_perm(num_bits: int, is_a_8bit: bool = False):
+    perm_list: list[int] = []
+    if is_a_8bit:
+        for i in range(32):
+            perm1 = []
+            col = i // 4
+            for block in [0, 1]:
+                for row in [
+                    4 * (i % 4),
+                    4 * (i % 4) + 1,
+                    4 * (i % 4) + 2,
+                    4 * (i % 4) + 3,
+                    4 * (i % 4 + 4),
+                    4 * (i % 4 + 4) + 1,
+                    4 * (i % 4 + 4) + 2,
+                    4 * (i % 4 + 4) + 3,
+                ]:
+                    perm1.append(16 * row + col + 8 * block)
+            for j in range(2):
+                perm_list.extend([p + 512 * j for p in perm1])
+    else:
+        for i in range(32):
+            perm1 = []
+            col = i // 4
+            for block in [0, 1]:
+                for row in [
+                    2 * (i % 4),
+                    2 * (i % 4) + 1,
+                    2 * (i % 4 + 4),
+                    2 * (i % 4 + 4) + 1,
+                ]:
+                    perm1.append(16 * row + col + 8 * block)
+            for j in range(4):
+                perm_list.extend([p + 256 * j for p in perm1])
+
+    perm = np.array(perm_list)
+
+    if num_bits == 4:
+        if is_a_8bit:  # noqa: SIM108
+            interleave = np.array([0, 4, 1, 5, 2, 6, 3, 7])
+        else:
+            interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
+    elif num_bits == 8:
+        if is_a_8bit:  # noqa: SIM108
+            interleave = np.array([0, 1, 2, 3])
+        else:
+            interleave = np.array([0, 2, 1, 3])
+    else:
+        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
+
+    perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
+    perm = torch.from_numpy(perm)
+    return perm
+
+
+GPTQ_MARLIN_TILE = 16
+GPTQ_MARLIN_MIN_THREAD_N = 64
+GPTQ_MARLIN_MIN_THREAD_K = 128
+GPTQ_MARLIN_MAX_PARALLEL = 16
+
+
+def marlin_permute_weights(
+    q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE, is_a_8bit=False
+):
+    assert q_w.shape == (size_k, size_n)
+    assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
+    assert size_n % tile == 0, f"size_n = {size_n}, tile = {tile}"
+
+    if is_a_8bit:
+        # Permute weights to 32x32 marlin tiles
+        q_w = q_w.reshape((size_k // (tile * 2), tile * 2, size_n // tile, tile))
+    else:
+        # Permute weights to 16x64 marlin tiles
+        q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
+    q_w = q_w.permute((0, 2, 1, 3))
+    q_w = q_w.reshape((size_k // tile, size_n * tile))
+
+    q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)
+
+    return q_w
+
+
+def get_pack_factor(num_bits):
+    assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
+    return 32 // num_bits
+
+
+def marlin_weights(q_w, size_k, size_n, num_bits, perm, is_a_8bit=False):
+    # Permute
+    q_w = marlin_permute_weights(q_w, size_k, size_n, perm, is_a_8bit=is_a_8bit)
+
+    # Pack
+    pack_factor = get_pack_factor(num_bits)
+    orig_device = q_w.device
+
+    q_w = q_w.cpu().numpy().astype(np.uint32)
+
+    q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32)
+    for i in range(pack_factor):
+        q_packed |= q_w[:, i::pack_factor] << num_bits * i
+
+    q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device)
+
+    return q_packed
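+
+
+# Packing layout (illustrative worked example): with num_bits = 4, so
+# pack_factor = 8, each 32-bit word of q_packed holds eight consecutive
+# permuted values little-end-first:
+#   q_packed[:, j] = q[:, 8j] | q[:, 8j+1] << 4 | ... | q[:, 8j+7] << 28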
+
+
+def marlin_permute_scales(
+    s: torch.Tensor, size_k: int, size_n: int, group_size: int, is_a_8bit: bool = False
+) -> torch.Tensor:
+    scale_perm, scale_perm_single = get_scale_perms()
+    if group_size < size_k and group_size != -1 and not is_a_8bit:
+        s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
+    else:
+        s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
+    s = s.reshape((-1, size_n)).contiguous()
+
+    return s
+
+
+def marlin_quantize(
+    w: torch.Tensor,
+    quant_type: ScalarType,
+    group_size: int,
+    act_order: bool,
+    test_perm: torch.Tensor | None = None,
+    input_dtype: torch.dtype | None = None,
+):
+    is_a_8bit = input_dtype is not None and input_dtype.itemsize == 1
+
+    size_k, size_n = w.shape
+    num_bits = quant_type.size_bits
+
+    if input_dtype == torch.float8_e4m3fn and quant_type == scalar_types.uint4b8:
+        raise ValueError(
+            f"unsupported combination: input_dtype = {input_dtype}, quant_type = {quant_type}"
+        )
+
+    # Normalize group_size
+    if group_size == -1:
+        group_size = size_k
+    assert group_size <= size_k
+
+    # Quantize (and apply act_order if provided)
+    w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights(
+        w, quant_type, group_size, act_order, test_perm
+    )
+
+    # For act_order, sort the "weights" and "g_idx" so that group ids are
+    # increasing
+    sort_indices = torch.empty(0, dtype=torch.int, device=w.device)
+    if act_order:
+        q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
+
+    # Reformat to marlin
+    weight_perm = get_weight_perm(num_bits, is_a_8bit)
+    marlin_q_w = marlin_weights(
+        q_w, size_k, size_n, num_bits, weight_perm, is_a_8bit=is_a_8bit
+    )
+    marlin_s = marlin_permute_scales(s, size_k, size_n, group_size, is_a_8bit=is_a_8bit)
+
+    # Create result
+    res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]
+    for i in range(len(res_list)):
+        res_list[i] = res_list[i].to(w.device)
+
+    return res_list
+
+
+def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
+    return torch.nn.Parameter(
+        torch.empty(0, dtype=torch.int, device=device), requires_grad=False
+    )
+
+
+def compute_max_diff(output, output_ref):
+    return torch.mean(torch.abs(output - output_ref)) / torch.mean(
+        torch.abs(output_ref)
+    )
+
+
+def awq_marlin_gemm_torch(a_input, w_ref, b_bias):
+    if b_bias is None:
+        return torch.matmul(a_input, w_ref)
+    return torch.matmul(a_input, w_ref) + b_bias.view(1, -1)
+
+
+def test_marlin_gemm_subset_input(
+    handle,
+    device,
+    size_m,
+    size_k,
+    size_n,
+    group_size,
+    quant_type,
+    dtype=InfiniDtype.F16,
+    sync=None,
+):
+    print(
+        f"Testing awq_marlin_gemm_subset_input on {device} with M-K-N: ({size_m}, {size_k}, {size_n}), group_size: {group_size}, dtype: {InfiniDtypeNames[dtype]}"
+    )
+    big_m = size_m * 2
+    big_k = size_k * 2
+    test_dtype = to_torch_dtype(dtype)
+
+    # Slice a non-contiguous window out of a larger buffer so the kernel is
+    # exercised with a strided A input (lda > size_k).
+    a_input = torch.randn((big_m, big_k), dtype=test_dtype)[
+        8 : size_m + 8, 8 : size_k + 8
+    ]
+    A = TestTensor(
+        a_input.shape,
+        a_input.stride(),
+        dtype,
+        device,
+        mode="manual",
+        set_tensor=a_input,
+    )
+    b_weight = TestTensor((size_k, size_n), None, dtype, device)
+
+    w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
+        b_weight.torch_tensor(), quant_type, group_size, False
+    )
+
+    marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
+
+    ans = awq_marlin_gemm_torch(A.torch_tensor(), w_ref, None)
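+    # w_ref is the dequantized reference weight returned by marlin_quantize,
+    # so this float matmul is the result the quantized kernel should
+    # approximate (checked against a relative-error bound below).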
print("g_idx", marlin_zp) + output = TestTensor(ans.shape, None, dtype, device, mode="zeros") + + B = TestTensor( + marlin_q_w.shape, + marlin_q_w.stride(), + InfiniDtype.I32, + device, + mode="manual", + set_tensor=marlin_q_w, + ) + b_bias = None + b_scales = TestTensor( + marlin_s.shape, + marlin_s.stride(), + dtype, + device, + mode="manual", + set_tensor=marlin_s, + ) + a_scales = None + global_scales = None + if marlin_zp is not None: + b_zeros = TestTensor( + marlin_zp.shape, + marlin_zp.stride(), + InfiniDtype.I32, + device, + mode="manual", + set_tensor=marlin_zp, + ) + else: + b_zeros = None + if g_idx is not None: + b_g_idx = TestTensor( + g_idx.shape, + g_idx.stride(), + InfiniDtype.I32, + device, + mode="manual", + set_tensor=g_idx, + ) + else: + b_g_idx = None + if sort_indices is not None: + perm = TestTensor( + sort_indices.shape, + sort_indices.stride(), + InfiniDtype.I32, + device, + mode="manual", + set_tensor=sort_indices, + ) + else: + perm = None + is_k_full = True + use_atomic_add = False + use_fp32_reduce = True + is_zp_float = False + + if sync is not None: + sync() + + descriptor = infiniopOperatorDescriptor_t() + check_error( + LIBINFINIOP.infiniopCreateAwqMarlinGemmDescriptor( + handle, + ctypes.byref(descriptor), + output.descriptor, + A.descriptor, + B.descriptor, + b_bias.descriptor if b_bias is not None else None, + b_scales.descriptor, + a_scales.descriptor if a_scales is not None else None, + global_scales.descriptor if global_scales is not None else None, + b_zeros.descriptor if b_zeros is not None else None, + b_g_idx.descriptor if b_g_idx is not None else None, + perm.descriptor if perm is not None else None, + ) + ) + + # Invalidate descriptors (same pattern as other tests) + for tensor in [ + output, + A, + B, + b_bias, + b_scales, + a_scales, + global_scales, + b_zeros, + b_g_idx, + perm, + ]: + if tensor is not None: + tensor.destroy_desc() + + workspace_size = c_uint64(0) + check_error( + LIBINFINIOP.infiniopGetAwqMarlinGemmWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = TestWorkspace(workspace_size.value, device) + + def lib_awq_marlin_gemm(): + check_error( + LIBINFINIOP.infiniopAwqMarlinGemm( + descriptor, + workspace.data(), + workspace_size.value, + output.data(), + A.data(), + B.data(), + b_bias.data() if b_bias is not None else None, + b_scales.data(), + a_scales.data() if a_scales is not None else None, + global_scales.data() if global_scales is not None else None, + b_zeros.data() if b_zeros is not None else None, + b_g_idx.data() if b_g_idx is not None else None, + perm.data() if perm is not None else None, + quant_type.id, + is_k_full, + use_atomic_add, + use_fp32_reduce, + is_zp_float, + None, + ) + ) + + lib_awq_marlin_gemm() + + max_diff = compute_max_diff(output.actual_tensor(), ans) + assert max_diff < 0.04 + + if PROFILE: + profile_operation( + "PyTorch", + lambda: awq_marlin_gemm_torch(A.torch_tensor(), w_ref, None), + device, + NUM_PRERUN, + NUM_ITERATIONS, + ) + profile_operation( + " lib", + lambda: lib_awq_marlin_gemm(), + device, + NUM_PRERUN, + NUM_ITERATIONS, + ) + + check_error(LIBINFINIOP.infiniopDestroyAwqMarlinGemmDescriptor(descriptor)) + + +def test_marlin_gemm_with_bias( + handle, + device, + size_m, + size_k, + size_n, + group_size, + quant_type, + dtype=InfiniDtype.F16, + sync=None, +): + print( + f"Testing awq_marlin_gemm_with_bias on {device} with M-K-N:({size_m, size_k, size_n}), group_size:{group_size}, dtype:{InfiniDtypeNames[dtype]}" + ) + + test_dtype = to_torch_dtype(dtype) + 
+    A = TestTensor((size_m, size_k), None, dtype, device)
+    b_weight = TestTensor((size_k, size_n), None, dtype, device)
+    b_bias = TestTensor((size_n,), None, dtype, device)
+
+    marlin_bias = marlin_permute_bias(b_bias.torch_tensor())
+    w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize(
+        b_weight.torch_tensor(), quant_type, group_size, False
+    )
+
+    marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
+
+    ans = awq_marlin_gemm_torch(A.torch_tensor(), w_ref, b_bias.torch_tensor())
+    output = TestTensor(ans.shape, None, dtype, device, mode="zeros")
+
+    B = TestTensor(
+        marlin_q_w.shape,
+        marlin_q_w.stride(),
+        InfiniDtype.I32,
+        device,
+        mode="manual",
+        set_tensor=marlin_q_w,
+    )
+    # Pass the permuted bias (marlin_bias) to the kernel; the unpermuted
+    # b_bias above is kept only for the PyTorch reference.
+    B_bias = TestTensor(
+        marlin_bias.shape,
+        marlin_bias.stride(),
+        dtype,
+        device,
+        mode="manual",
+        set_tensor=marlin_bias,
+    )
+    b_scales = TestTensor(
+        marlin_s.shape,
+        marlin_s.stride(),
+        dtype,
+        device,
+        mode="manual",
+        set_tensor=marlin_s,
+    )
+    a_scales = None
+    global_scales = None
+    if marlin_zp is not None:
+        b_zeros = TestTensor(
+            marlin_zp.shape,
+            marlin_zp.stride(),
+            InfiniDtype.I32,
+            device,
+            mode="manual",
+            set_tensor=marlin_zp,
+        )
+    else:
+        b_zeros = None
+    if g_idx is not None:
+        b_g_idx = TestTensor(
+            g_idx.shape,
+            g_idx.stride(),
+            InfiniDtype.I32,
+            device,
+            mode="manual",
+            set_tensor=g_idx,
+        )
+    else:
+        b_g_idx = None
+    if sort_indices is not None:
+        perm = TestTensor(
+            sort_indices.shape,
+            sort_indices.stride(),
+            InfiniDtype.I32,
+            device,
+            mode="manual",
+            set_tensor=sort_indices,
+        )
+    else:
+        perm = None
+    is_k_full = True
+    use_atomic_add = False
+    use_fp32_reduce = True
+    is_zp_float = False
+
+    if sync is not None:
+        sync()
+
+    descriptor = infiniopOperatorDescriptor_t()
+    check_error(
+        LIBINFINIOP.infiniopCreateAwqMarlinGemmDescriptor(
+            handle,
+            ctypes.byref(descriptor),
+            output.descriptor,
+            A.descriptor,
+            B.descriptor,
+            B_bias.descriptor,
+            b_scales.descriptor,
+            a_scales.descriptor if a_scales is not None else None,
+            global_scales.descriptor if global_scales is not None else None,
+            b_zeros.descriptor if b_zeros is not None else None,
+            b_g_idx.descriptor if b_g_idx is not None else None,
+            perm.descriptor if perm is not None else None,
+        )
+    )
+
+    # Invalidate descriptors (same pattern as other tests)
+    for tensor in [
+        output,
+        A,
+        B,
+        B_bias,
+        b_scales,
+        a_scales,
+        global_scales,
+        b_zeros,
+        b_g_idx,
+        perm,
+    ]:
+        if tensor is not None:
+            tensor.destroy_desc()
+
+    workspace_size = c_uint64(0)
+    check_error(
+        LIBINFINIOP.infiniopGetAwqMarlinGemmWorkspaceSize(
+            descriptor, ctypes.byref(workspace_size)
+        )
+    )
+    workspace = TestWorkspace(workspace_size.value, device)
+
+    def lib_awq_marlin_gemm():
+        check_error(
+            LIBINFINIOP.infiniopAwqMarlinGemm(
+                descriptor,
+                workspace.data(),
+                workspace_size.value,
+                output.data(),
+                A.data(),
+                B.data(),
+                B_bias.data(),
+                b_scales.data(),
+                a_scales.data() if a_scales is not None else None,
+                global_scales.data() if global_scales is not None else None,
+                b_zeros.data() if b_zeros is not None else None,
+                b_g_idx.data() if b_g_idx is not None else None,
+                perm.data() if perm is not None else None,
+                quant_type.id,
+                is_k_full,
+                use_atomic_add,
+                use_fp32_reduce,
+                is_zp_float,
+                None,
+            )
+        )
+
+    lib_awq_marlin_gemm()
+
+    max_diff = compute_max_diff(output.actual_tensor(), ans)
+    assert max_diff < 0.04, f"relative error too large: {max_diff}"
+
+    if PROFILE:
+        profile_operation(
+            "PyTorch",
+            lambda: awq_marlin_gemm_torch(
+                A.torch_tensor(), w_ref, b_bias.torch_tensor()
+            ),
+            device,
+            NUM_PRERUN,
+            NUM_ITERATIONS,
+        )
+        profile_operation(
+            "    lib",
+            lambda: lib_awq_marlin_gemm(),
+            device,
+            NUM_PRERUN,
+            NUM_ITERATIONS,
+        )
+
+    check_error(LIBINFINIOP.infiniopDestroyAwqMarlinGemmDescriptor(descriptor))
+
+
+if __name__ == "__main__":
+    args = get_args()
+
+    DEBUG = args.debug
+    PROFILE = args.profile
+    NUM_PRERUN = args.num_prerun
+    NUM_ITERATIONS = args.num_iterations
+
+    for device in get_test_devices(args):
+        test_operator(
+            device,
+            test_marlin_gemm_subset_input,
+            _TEST_CASES_SUBSET_INPUT,
+            _TENSOR_DTYPES,
+        )
+        test_operator(
+            device, test_marlin_gemm_with_bias, _TEST_CASES_WITH_BIAS, _TENSOR_DTYPES
+        )
+
+    print("\033[92mTest passed!\033[0m")
diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py
index 26ade4c2a..27979daae 100644
--- a/test/infiniop/libinfiniop/op_register.py
+++ b/test/infiniop/libinfiniop/op_register.py
@@ -1353,6 +1353,60 @@ def gptq_qyblas_gemm_(lib):
         infiniopOperatorDescriptor_t,
     ]
 
+
+@OpRegister.operator
+def awq_marlin_gemm_(lib):
+    lib.infiniopCreateAwqMarlinGemmDescriptor.restype = c_int32
+    lib.infiniopCreateAwqMarlinGemmDescriptor.argtypes = [
+        infiniopHandle_t,
+        POINTER(infiniopOperatorDescriptor_t),
+        infiniopTensorDescriptor_t,
+        infiniopTensorDescriptor_t,
+        infiniopTensorDescriptor_t,
+        infiniopTensorDescriptor_t,
+        infiniopTensorDescriptor_t,
+        infiniopTensorDescriptor_t,
+        infiniopTensorDescriptor_t,
+        infiniopTensorDescriptor_t,
+        infiniopTensorDescriptor_t,
+        infiniopTensorDescriptor_t,
+    ]
+
+    lib.infiniopGetAwqMarlinGemmWorkspaceSize.restype = c_int32
+    lib.infiniopGetAwqMarlinGemmWorkspaceSize.argtypes = [
+        infiniopOperatorDescriptor_t,
+        POINTER(c_size_t),
+    ]
+
+    lib.infiniopAwqMarlinGemm.restype = c_int32
+    lib.infiniopAwqMarlinGemm.argtypes = [
+        infiniopOperatorDescriptor_t,
+        c_void_p,
+        c_size_t,
+        c_void_p,
+        c_void_p,
+        c_void_p,
+        c_void_p,
+        c_void_p,
+        c_void_p,
+        c_void_p,
+        c_void_p,
+        c_void_p,
+        c_void_p,
+        c_int64,
+        c_bool,
+        c_bool,
+        c_bool,
+        c_bool,
+        c_void_p,
+    ]
+
+    lib.infiniopDestroyAwqMarlinGemmDescriptor.restype = c_int32
+    lib.infiniopDestroyAwqMarlinGemmDescriptor.argtypes = [
+        infiniopOperatorDescriptor_t,
+    ]
+
+
 @OpRegister.operator
 def softplus_(lib):
     lib.infiniopCreateSoftplusDescriptor.restype = c_int32