enable blockwise FP8 quantization on rocm #609

Open
asdfvg123 wants to merge 9 commits into dev from yeonsoo/blockwise_fp8

@asdfvg123

Description

Please include a brief summary of the changes, relevant motivation and context.

Enable blockwise FP8 quantization on rocm

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Remove the HIP guard in quantization.py
  • Guard the quantization kernels that use TMA
  • Add a branch to handle the different number of threads per wave on ROCm

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

# TODO replace with call to fp8.py when recipe added.
recipe_available = not IS_HIP_EXTENSION and (get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.8)
if IS_HIP_EXTENSION:
    recipe_available = get_device_compute_capability() >= (9, 0)

Contributor

Wouldn't this always be True on ROCm TE?

@asdfvg123 asdfvg123 Jun 4, 2026

Author

This test targets MI300 and MI350, so I set it to (9, 0).

Contributor

I believe MI250 is (9, 0), so this should be > rather than >=, or the comparison should be against (9, 4).
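
The point about >= versus > can be illustrated with a short Python sketch of how tuple comparison behaves. The capability values for MI250 and MI300 are taken from the comments in this thread; the MI350 value and the helper name supports_blockwise_fp8 are assumptions made only for illustration, not part of the PR.

```python
# Hypothetical capability table. MI250 and MI300 follow the review comments;
# the MI350 value is an assumption for illustration.
ASSUMED_CAPABILITY = {
    "MI250": (9, 0),
    "MI300": (9, 4),
    "MI350": (9, 5),
}


def supports_blockwise_fp8(capability, minimum=(9, 4)):
    """Tuples compare lexicographically, so (9, 0) >= (9, 0) is True."""
    return capability >= minimum


for name, cap in ASSUMED_CAPABILITY.items():
    loose = cap >= (9, 0)                  # the check under review
    strict = supports_blockwise_fp8(cap)   # the suggested (9, 4) threshold
    print(f"{name}: >= (9, 0) is {loose}, >= (9, 4) is {strict}")
```

Under these assumed values, only the (9, 4) threshold excludes MI250 while still admitting the two targets named by the author.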

@@ -1 +1 @@
/*************************************************************************

Contributor

Needs AMD copyright

Author

added

Comment on lines +8 to +24
#ifndef __HIP_PLATFORM_AMD__
#include <cudaTypedefs.h>
#endif
#include <cuda_bf16.h>
#include <cuda_runtime.h>

#include <cfloat>
#ifndef __HIP_PLATFORM_AMD__
#include <cuda/barrier>
#endif

#include "common/common.h"
#include "common/recipe/recipe_common.cuh"
#include "common/util/cuda_runtime.h"
#ifndef __HIP_PLATFORM_AMD__
#include "common/util/ptx.cuh"
#endif

Contributor

These #includes should be already disabled via hipify, so probably no need for the #ifndefs here.

Author

fixed

Comment thread transformer_engine/common/common.h Outdated
Comment on lines +639 to +640
static constexpr float max = 448.0f;
static constexpr float max_inverse = 1.0 / max;

Contributor

Is this change necessary? fp8e4m3 max depends on the device type on AMD.

Author

quantize_transpose_square_blockwise.cu and quantize_transpose_vector_blockwise.cu use compute_scale_from_types<IType, fp8e4m3> for the first time, which exposed a latent bug in common.h.

The #else branch of TypeExtrema<fp8e4m3> declared max as a static float. This caused the constexpr static float max_finite_value initializer in TypeInfo, in the same file, to fail when the template was instantiated on the host.

The fix uses HIP_FP8_TYPE_FNUZ, which hip_float8.h uses to select FNUZ at compile time, to make the host-pass branch constexpr as well.

Collaborator

If the value is really used on the host side, it should be detected at runtime. If it is only for the host translation of GPU code (i.e., the results are discarded), you can keep 448; no extra ifdefs are needed.

Author

Reverted to the original upstream code. I instead changed recipe_common.cuh, following the convention used in quantize_transpose_vector_blockwise_fp4.cu L230.

@alextmagro

alextmagro commented Jun 3, 2026

Contributor

Could you give a description of what you want to achieve with this PR? My understanding is that block fp8 quantization relies on some upstream kernels that will need to be adapted for AMD.

If you're just trying to enable the interface, I would argue that we should do this last, after we have a working quantization and GEMM path (and enabled and passing C++/Python tests).

@asdfvg123

Author

@alextmagro
This PR enables only quantization on AMD GPUs, not the GEMM. Upstream has two kinds of quantization kernels: those that use TMA and those that do not. I guarded the kernels that use TMA and use the non-TMA kernels to quantize on AMD.

I tested with tests/pytorch/test_float8blockwisetensor.py and it passes [175 passed / 32 xpassed / 5 warnings].

@alextmagro

Contributor

OK, in that case we need to add the cpp blockwise tests to the CMake file, and the pytorch test file to ci/pytorch.sh.

…dant HIP guards, revert unnecessary common.h change
Comment thread ci/pytorch.sh
Comment thread tests/cpp/operator/test_cast_float8blockwise.cu Outdated
Comment thread tests/cpp/operator/test_cast_float8blockwise.cu Outdated
Comment thread tests/pytorch/test_float8blockwisetensor.py Outdated
Comment thread transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu Outdated
Comment thread transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu Outdated
constexpr int kNVecSMem = 2; // The number of elements each LDS/STS touches

#ifdef __HIP_PLATFORM_AMD__
constexpr int kThreadsPerBlock = 512; // Thread block size, 8 warps (wave64) in total

Contributor

Are there actual performance improvements for increasing the # of threads and the threads per warp? If not, we should use the already present values for now.

Author

The kernel expects 8 waves per block, so I increased the number of threads.

Contributor

This would be platform specific I think -- probably need the gfx1250 guard here too

Contributor

I think this kernel would work with 4 waves/block as well, right? That might give better performance given the waves are twice as large as upstream.

Author

Added the gfx1250 guard.
At line 779 there is a static_assert on num_iterations == 1, which requires 8 waves per block for gfx942 and gfx950. This is necessary; 4 waves per block fails to compile.
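
The wave counts discussed in this thread follow from simple division. A small Python sketch is below; the 256-thread, 32-lane upstream configuration is an assumption (consistent with the remark that the AMD waves are twice as large as upstream), and waves_per_block is a hypothetical helper, not a function from the code base.

```python
def waves_per_block(threads_per_block, threads_per_wave):
    """Number of waves (warps) in one thread block."""
    if threads_per_block % threads_per_wave != 0:
        raise ValueError("block size must be a multiple of the wave size")
    return threads_per_block // threads_per_wave


print(waves_per_block(256, 32))  # assumed upstream CUDA configuration
print(waves_per_block(512, 64))  # wave64 configuration from this PR
print(waves_per_block(256, 64))  # the 4 waves per block alternative
```

This only shows how the block size and wave size relate; it does not reproduce the num_iterations computation behind the static_assert.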

Comment thread transformer_engine/pytorch/quantization.py
@alextmagro

alextmagro commented Jun 4, 2026

Contributor

By the way, to run CI you need to add a CI level label. L3 is required before merging; L1 is for lighter testing (mostly sGPU tests) when you are midway through the ticket and expect to make more changes.

@asdfvg123 asdfvg123 added the ci-level 1 CI test level 1 label Jun 4, 2026

Collaborator

Copyright

Author

added


#ifdef __HIP_PLATFORM_AMD__
using WarpSyncMask = uint64_t;
constexpr WarpSyncMask kFullWarpMask = 0xFFFFFFFFFFFFFFFFULL;

Collaborator

ROCm should not use it. See how *_sync calls are guarded in other places

Author

Removed the mask; the ROCm path now uses __shfl instead of __shfl_sync.

}
}
// Reduce amax in the warp (32x32 tile)
#ifdef __HIP_PLATFORM_AMD__

Collaborator

This whole block is under #ifndef __HIP_PLATFORM_AMD__

Author

removed the dead branch

// const values configuration

#ifdef __HIP_PLATFORM_AMD__
constexpr size_t kThreadsPerWarp = 64;

Collaborator

It is platform dependent.

Author

Fixed; it is now guarded with gfx1250 for 32 threads.

Contributor

Can we use warpSize from hipruntime here, since kThreadsPerWarp is only needed for device code?

Collaborator

I think warpSize is not constexpr anymore. Or is it?

@asdfvg123 asdfvg123 Jun 15, 2026

Author

I assume you are referring to

inline __device__ const struct {
  __device__ __attribute__((always_inline, const)) operator int() const noexcept {
    return __builtin_amdgcn_wavefrontsize();
  }
} warpSize{};

in amd_warp_functions.h. This is not constexpr (its value is determined at runtime), so it cannot be used here.

transpose/multi_cast_transpose.cu
transpose/quantize_transpose_vector_blockwise.cu #CUDA-only
transpose/quantize_transpose_vector_blockwise.cu
transpose/quantize_transpose_square_blockwise.cu

Collaborator

It should stay in transformer_engine_cuda_arch_specific_sources

Author

fixed

@asdfvg123

Author

MI300 has 64 KB of LDS, which overflows when a 128 x 128 tile of FP32 data is loaded into LDS. I created a helper and branched the kernel: when loading FP32 data, the kernel loads the tile in 128 x 64 chunks and iterates to quantize them. From the host's point of view, the kernel still quantizes 128 x 128 elements.
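
The LDS budget described here can be checked with a few lines of arithmetic. The sketch below is in Python; tile_bytes is a hypothetical helper, and the only inputs are the numbers quoted in the comment (64 KB of LDS, 4-byte FP32 elements, 128 x 128 and 128 x 64 tiles).

```python
LDS_BYTES = 64 * 1024   # MI300 LDS size quoted in the comment
FP32_BYTES = 4          # size of one FP32 element


def tile_bytes(rows, cols, elem_bytes=FP32_BYTES):
    """Bytes needed to stage a rows x cols tile in shared memory."""
    return rows * cols * elem_bytes


full_tile = tile_bytes(128, 128)   # staging the whole tile at once
half_tile = tile_bytes(128, 64)    # staging one chunk at a time
num_chunks = 128 // 64             # chunks needed to cover the tile

print(full_tile, full_tile >= LDS_BYTES)   # 65536 True
print(half_tile, half_tile < LDS_BYTES)    # 32768 True
print(num_chunks)                          # 2
```

The full FP32 tile alone consumes the entire 64 KB, so any additional shared-memory use would push the kernel over the limit, while the chunked variant needs half of that per iteration.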

}
// Reduce amax in the warp (32x32 tile)
#ifdef __HIP_PLATFORM_AMD__
warp_tile_amax = warp_reduce_max_64(amax);

Contributor

This is platform dependent. Can we just manually unroll the function inline here, and guard the last shuffle for 64 thread platforms?

Also, instead of a reduce-broadcast we can do an all reduce across threads using __shfl_xor?

Author

Fixed. I changed it to use __shfl_xor.
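
The XOR-based all-reduce suggested above can be modelled on the host to show why every lane ends up with the maximum. This is a Python emulation written for illustration only, not the kernel code from the PR; shfl_xor and warp_all_reduce_max are hypothetical names, and the list index stands in for the lane ID.

```python
def shfl_xor(values, lane_mask):
    """Emulate __shfl_xor: lane i reads the value held by lane i ^ lane_mask."""
    return [values[i ^ lane_mask] for i in range(len(values))]


def warp_all_reduce_max(values):
    """Butterfly reduction: after log2(width) steps every lane holds the max."""
    width = len(values)
    assert width > 0 and width & (width - 1) == 0, "width must be a power of two"
    vals = list(values)
    delta = width // 2
    while delta > 0:
        other = shfl_xor(vals, delta)
        vals = [max(a, b) for a, b in zip(vals, other)]
        delta //= 2
    return vals


# 64 lanes (wave64) holding a permutation of 0..63 as their local amax.
lanes = [float((7 * i) % 64) for i in range(64)]
result = warp_all_reduce_max(lanes)
print(all(v == max(lanes) for v in result))  # True
```

Because each step exchanges values between lane pairs in both directions, no separate broadcast from lane 0 is needed afterwards, which is the difference from a reduce followed by a broadcast.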

#ifndef __HIP_PLATFORM_AMD__
const bool full_tile =
row_length % BLOCK_TILE_DIM == 0 && num_rows % BLOCK_TILE_DIM == 0;

Contributor

unneeded whitespace change

Author

fixed

constexpr int kNumWarps = kThreadsPerBlock / kThreadsPerWarp;

// gfx942 (MI300) has 64KB LDS; the full 128x128 fp32 staging tile overflows it.
#if defined(__HIP_PLATFORM_AMD__) && !defined(__gfx950__)

Contributor

Probably should specifically check for gfx942 here and elsewhere instead of !defined(__gfx950__).

Author

Fixed:
#if defined(__HIP_PLATFORM_AMD__) && (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx942__))
The code block has to be compiled in both the host pass and the gfx942 device pass, and must NOT be compiled in the NVIDIA pass. The guard is true only for the AMD host pass or the gfx942 device pass.
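
The behaviour of this guard across compilation passes can be tabulated with a small Python model. The macro sets assigned to each pass below are simplified assumptions for illustration (a real compilation defines many more macros), and guard_active is a hypothetical helper that mirrors the #if expression.

```python
def guard_active(defined):
    """Mirror of: defined(__HIP_PLATFORM_AMD__) &&
    (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx942__))"""
    return "__HIP_PLATFORM_AMD__" in defined and (
        "__HIP_DEVICE_COMPILE__" not in defined or "__gfx942__" in defined
    )


# Assumed macro sets per pass, reduced to the three macros the guard reads.
PASSES = {
    "NVIDIA pass": set(),
    "AMD host pass": {"__HIP_PLATFORM_AMD__"},
    "AMD gfx942 device pass": {
        "__HIP_PLATFORM_AMD__", "__HIP_DEVICE_COMPILE__", "__gfx942__"},
    "AMD gfx950 device pass": {
        "__HIP_PLATFORM_AMD__", "__HIP_DEVICE_COMPILE__", "__gfx950__"},
}

for name, macros in PASSES.items():
    print(f"{name}: {guard_active(macros)}")
```

Under these assumptions the guard is active for the AMD host pass and the gfx942 device pass, and inactive for the NVIDIA pass and the gfx950 device pass.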

const int c_s = warp_in_chunk * num_smem_reads;
size_t r_g = static_cast<size_t>(blockIdx.y) * kTileDim + r_s;
for (int chunk = 0; chunk < kNumChunks; ++chunk) {
__syncthreads();

Contributor

We can probably skip the syncthreads for the first iteration, also a pragma unroll might help here.

Author

fixed and added

Comment thread transformer_engine/pytorch/quantization.py
Comment thread transformer_engine/common/recipe/recipe_common.cuh

#ifdef __HIP_PLATFORM_AMD__
__device__ __forceinline__ float blockwise_warp_reduce_max(float val) {
__device__ __forceinline__ float warp_reduce_max_64(float val) {

Collaborator

why 64?

Author

is now removed.

// Step 2.3: Reduce amax
#pragma unroll
for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) {
const float other_amax = __shfl_down_sync(mask, amax, delta);

Collaborator

Use __shfl_down on ROCm

Author

I removed all *_sync calls from both the AMD-only path and the path shared by AMD and NVIDIA. I added a guard and use the non-sync variants in the AMD path.

using transformer_engine::detail::FP8BlockwiseRowwiseOption;

#ifdef __HIP_PLATFORM_AMD__
using WarpSyncMask = uint64_t;

Collaborator

Review where it is used. Wavefront-level primitives on ROCm should not use a mask.

Author

Removed the unnecessary mask definitions together with using only non-sync in the AMD path. Reverted to the upstream.

if IS_HIP_EXTENSION:
return False, "FP8 block scaled gemm not yet supported for ROCm"
gpu_arch = get_device_compute_capability()
if gpu_arch >= (9, 0):

Collaborator

FP8 starts from 9.4

Author

fixed

@asdfvg123 asdfvg123 force-pushed the yeonsoo/blockwise_fp8 branch from dc4c5fd to 70c35df Compare June 15, 2026 23:15
@asdfvg123 asdfvg123 requested review from alextmagro and ipanfilo June 16, 2026 00:00
// const values configuration

#if defined(__HIP_PLATFORM_AMD__) && !defined(__gfx1250__)
constexpr size_t kThreadsPerWarp = 64;

Collaborator

is kThreadsPerWarp only used by device code and not any dispatch functions?

// Reduce amax in the warp (32x32 tile)
#ifdef __HIP_PLATFORM_AMD__
#pragma unroll
for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) {

Collaborator

Can you please clarify this logic with using xor?

Labels

ci-level 1 CI test level 1

4 participants