Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions tests/cpp/operator/test_cublaslt_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,9 @@ void performTest(const TestParams& params) {
const bool has_fp8 = isFp8Type(atype) || isFp8Type(btype);
const bool use_mxfp8 = params.scaling_mode == NVTEScalingMode::NVTE_MXFP8_1D_SCALING;

cudaDeviceProp prop;
(void)cudaGetDeviceProperties(&prop, 0);

if (use_mxfp8)
{
if (!has_fp8) {
Expand All @@ -471,14 +474,15 @@ void performTest(const TestParams& params) {
if (params.m % 16 || params.n % 16) {
GTEST_SKIP() << "MXFP8 requires M & N to be multiples of 16";
}
if (params.k % 128) {
GTEST_SKIP() << "MXFP8 requires K to be a multiple of 128";
size_t required_k_multiple = 128;
#ifdef __HIP_PLATFORM_AMD__
required_k_multiple = (prop.major == 12 && prop.minor == 5) ? 32 : 128;
#endif
if (params.k % required_k_multiple) {
GTEST_SKIP() << "MXFP8 requires K to be a multiple of " << required_k_multiple;
}
}

cudaDeviceProp prop;
(void)cudaGetDeviceProperties(&prop, 0);

#ifdef __HIP_PLATFORM_AMD__

#if HIP_VERSION < 70200000
Expand Down Expand Up @@ -695,16 +699,20 @@ void performDqTest(const TestParams &params) {
GTEST_ASSERT_TRUE(isFp8Type(atype) && isFp8Type(btype)) << "FP8/BF8 input datatype is expected";
GTEST_ASSERT_FALSE(isFp8Type(dtype)) << "Non FP8/BF8 output datatype is expected";

cudaDeviceProp prop;
(void)cudaGetDeviceProperties(&prop, 0);

if (params.m % 16 || params.n % 16) {
GTEST_SKIP() << "MXFP8 requires M & N to be multiples of 16";
}
if (params.k % 128) {
GTEST_SKIP() << "MXFP8 requires K to be a multiple of 128";
size_t required_k_multiple = 128;
#ifdef __HIP_PLATFORM_AMD__
required_k_multiple = (prop.major == 12 && prop.minor == 5) ? 32 : 128;
#endif
if (params.k % required_k_multiple) {
GTEST_SKIP() << "MXFP8 requires K to be a multiple of " << required_k_multiple;
}

cudaDeviceProp prop;
(void)cudaGetDeviceProperties(&prop, 0);

bool mxfp8_supported = (prop.major == 9 && prop.minor >= 5) || prop.major >= 12;
if (!mxfp8_supported) {
GTEST_SKIP() << "MXFP8 is not supported in current config";
Expand Down
11 changes: 9 additions & 2 deletions transformer_engine/common/gemm/rocm_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <cstring>

#include "../common.h"
#include "../util/cuda_runtime.h"
#include "../util/vectorized_pointwise.h"
#include "../util/logging.h"

Expand Down Expand Up @@ -1736,10 +1737,16 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
NVTE_CHECK((is_transb ? B0 : B1) == k,
"GEMM inputs have incompatible dimensions (A is ", A0, "x", A1, ", B is ", B0, "x", B1,
")");
// Check that K is a multiple of 128, and M/N are multiples of 16 for MXFP8 GEMM
// Check that K is compatible with the MXFP8 scale layout, and M/N are multiples of 16
if (inputA->scaling_mode == NVTE_MXFP8_1D_SCALING || inputB->scaling_mode == NVTE_MXFP8_1D_SCALING) {
const bool is_gfx1250 = cuda::sm_arch() == 125;
// TODO: Also use 32 for gfx950 once hipBLASLt (and TE) support MXFP8 GEMM with
// swizzled scales on that architecture.
const int required_k_multiple = is_gfx1250 ? 32 : 128;

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a TODO here to change this for gfx950 after scale preswizzle is in hipblasLt.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added in 3a7dd8f

NVTE_CHECK(inputBias->data.dptr == nullptr, "MXFP8 GEMM does not yet support bias.");
NVTE_CHECK((k % 128) == 0, "GEMM K dimension must be multiple of 128 for MXFP8 scaling (got K=", k, ")");
NVTE_CHECK((k % required_k_multiple) == 0,
"GEMM K dimension must be multiple of ", required_k_multiple,
" for MXFP8 scaling (got K=", k, ")");
NVTE_CHECK((m % 16) == 0, "GEMM M dimension must be multiple of 16 for MXFP8 scaling (got M=", m, ")");
NVTE_CHECK((n % 16) == 0, "GEMM N dimension must be multiple of 16 for MXFP8 scaling (got N=", n, ")");
}
Expand Down
Loading