Enable blockwise FP8 quantization on ROCm #609
| # TODO replace with call to fp8.py when recipe added. | ||
| recipe_available = not IS_HIP_EXTENSION and (get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.8) | ||
| if IS_HIP_EXTENSION: | ||
| recipe_available = get_device_compute_capability() >= (9, 0) |
Wouldn't this always be True on ROCm TE?
This test targets MI300 and MI350, so I set it to (9, 0).
I believe MI250 is (9, 0), so this should be > rather than >=, or the threshold should be (9, 4).
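To make the point about the threshold concrete, here is a minimal Python sketch of how the tuple comparison behaves. Only the MI250 value, (9, 0), comes from this thread; the (9, 4) and (9, 5) entries and the helper name recipe_available_for are assumptions for illustration, not the actual test code.

```python
# Assumed (major, minor) values returned by get_device_compute_capability().
# Only MI250 = (9, 0) is stated in the thread; the others are assumptions.
ASSUMED_CAPABILITY = {
    "MI250": (9, 0),
    "MI300": (9, 4),  # assumption: gfx942
    "MI350": (9, 5),  # assumption: gfx950
}


def recipe_available_for(capability, threshold):
    """Hypothetical helper mirroring the tuple comparison used in the test."""
    return capability >= threshold


for name, cap in ASSUMED_CAPABILITY.items():
    # Python compares tuples element by element, so (9, 0) >= (9, 0) is True.
    print(name, recipe_available_for(cap, (9, 0)), recipe_available_for(cap, (9, 4)))
```

Under these assumptions, a threshold of (9, 0) also enables the recipe on MI250, while (9, 4) restricts it to the two targeted devices.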
| @@ -1 +1 @@ | |||
| /************************************************************************* | |||
| #ifndef __HIP_PLATFORM_AMD__ | ||
| #include <cudaTypedefs.h> | ||
| #endif | ||
| #include <cuda_bf16.h> | ||
| #include <cuda_runtime.h> | ||
|
|
||
| #include <cfloat> | ||
| #ifndef __HIP_PLATFORM_AMD__ | ||
| #include <cuda/barrier> | ||
| #endif | ||
|
|
||
| #include "common/common.h" | ||
| #include "common/recipe/recipe_common.cuh" | ||
| #include "common/util/cuda_runtime.h" | ||
| #ifndef __HIP_PLATFORM_AMD__ | ||
| #include "common/util/ptx.cuh" | ||
| #endif |
These #includes should be already disabled via hipify, so probably no need for the #ifndefs here.
| static constexpr float max = 448.0f; | ||
| static constexpr float max_inverse = 1.0 / max; |
Is this change necessary? fp8e4m3 max depends on the device type on AMD.
quantize_transpose_square_blockwise.cu and quantize_transpose_vector_blockwise.cu use
compute_scale_from_types<IType, fp8e4m3> for the first time, which exposed a latent bug in common.h.
The #else branch of TypeExtrema<fp8e4m3> declared max as a static float rather than a constexpr one.
This caused the constexpr static float max_finite_value initializer in TypeInfo, in the same file, to fail when the template was instantiated on the host.
The fix uses HIP_FP8_TYPE_FNUZ, which hip_float8.h uses to select FNUZ at compile time, to make the host-pass branch constexpr as well.
If the value is really used on the host side, it should be detected at runtime. If it is only needed for the host pass over GPU code (i.e., the results are discarded), you can keep 448; no extra ifdefs are needed.
Reverted to the original upstream code. Instead, I changed recipe_common.cuh, following the convention used in quantize_transpose_vector_blockwise_fp4.cu L230.
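For readers wondering why a single 448 constant cannot be correct on every AMD device, the following Python sketch derives the largest finite value of the two common e4m3 encodings from their bit layouts. The function name e4m3_max is hypothetical, and which device uses which encoding is not stated in this thread, so treat the mapping to hardware as an open assumption.

```python
def e4m3_max(bias, exp_field, mantissa):
    """Value of a positive 1-4-3 float with the given exponent and mantissa fields."""
    return 2.0 ** (exp_field - bias) * (1.0 + mantissa / 8.0)


# OCP e4m3: bias 7. Exponent 15 with mantissa 0b111 encodes NaN,
# so the largest finite value uses mantissa 0b110.
ocp_max = e4m3_max(bias=7, exp_field=15, mantissa=6)

# FNUZ e4m3: bias 8. NaN is the negative-zero bit pattern,
# so exponent 15 with mantissa 0b111 is a finite value.
fnuz_max = e4m3_max(bias=8, exp_field=15, mantissa=7)

print(ocp_max, fnuz_max)  # 448.0 240.0
```

The two maxima differ, which is why the reviewer asks for runtime detection whenever the host-side value is actually consumed.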
|
Could you give a description of what you want to achieve with this PR? My understanding is that block fp8 quantization relies on some upstream kernels that will need to be adapted for AMD. If you're just trying to enable the interface, I would argue that we should do this last, after we have a working quantization and GEMM path (and enabled and passing C++/Python tests). |
|
@alextmagro I tested with |
OK, in that case we need to add the cpp blockwise tests to the CMake file, and the pytorch test file to ci/pytorch.sh. |
…dant HIP guards, revert unnecessary common.h change
| constexpr int kNVecSMem = 2; // The number of elements each LDS/STS touches | ||
|
|
||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| constexpr int kThreadsPerBlock = 512; // Thread block size, 8 warps (wave64) in total |
Are there actual performance improvements for increasing the # of threads and the threads per warp? If not, we should use the already present values for now.
The kernel expects 8 waves/block, so I increased the number of threads.
This would be platform specific I think -- probably need the gfx1250 guard here too
I think this kernel would work with 4 waves/block as well, right? That might give better performance given the waves are twice as large as upstream.
Added the gfx1250 guard.
At line 779 there is a static_assert on num_iterations == 1, which requires 8 waves/block on gfx942 and gfx950. The change is therefore necessary: 4 waves/block fails to compile.
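The wave counts discussed in this thread follow from simple division. The sketch below assumes an upstream block size of 256 threads with 32-thread warps, which is inferred from the "8 warps" and "twice as large as upstream" remarks rather than quoted from the source; waves_per_block is a hypothetical helper.

```python
def waves_per_block(threads_per_block, threads_per_wave):
    """Number of waves (warps) that make up one thread block."""
    return threads_per_block // threads_per_wave


print(waves_per_block(512, 64))  # wave64 with the enlarged block
print(waves_per_block(256, 64))  # wave64 with the assumed upstream block size
print(waves_per_block(256, 32))  # 32-thread warps with the assumed upstream block size
```

With wave64, keeping 256 threads per block would halve the wave count to 4, which is the configuration the author reports as failing the static_assert.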
|
|
||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| using WarpSyncMask = uint64_t; | ||
| constexpr WarpSyncMask kFullWarpMask = 0xFFFFFFFFFFFFFFFFULL; |
ROCm should not use it. See how *_sync calls are guarded in other places
Removed the mask; the ROCm path now uses __shfl instead of __shfl_sync.
| } | ||
| } | ||
| // Reduce amax in the warp (32x32 tile) | ||
| #ifdef __HIP_PLATFORM_AMD__ |
This whole block is already under #ifndef __HIP_PLATFORM_AMD__.
| // const values configuration | ||
|
|
||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| constexpr size_t kThreadsPerWarp = 64; |
It is platform dependent.
Fixed: it is now guarded with gfx1250 for 32 threads.
Can we use warpSize from the HIP runtime here, since kThreadsPerWarp is only needed for device code?
I think warpSize is not constexpr anymore. Or is it?
I assume you are referring to
inline __device__ const struct {
__device__ __attribute__((always_inline, const)) operator int() const noexcept {
return __builtin_amdgcn_wavefrontsize();
}
} warpSize{};
in amd_warp_functions.h.
This is not constexpr (its value is determined at runtime), so it cannot be used here.
| transpose/multi_cast_transpose.cu | ||
| transpose/quantize_transpose_vector_blockwise.cu #CUDA-only | ||
| transpose/quantize_transpose_vector_blockwise.cu | ||
| transpose/quantize_transpose_square_blockwise.cu |
It should stay in transformer_engine_cuda_arch_specific_sources
|
MI300 has 64KB of LDS, which overflows when a 128 * 128 tile of FP32 data is loaded into it. I created a helper and branched the kernel: when loading FP32 data, the kernel loads one 128 * 64 chunk at a time and iterates over the chunks to quantize. From the host's point of view, the kernel still quantizes 128 * 128 elements. |
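The LDS budget behind this change can be checked with a few lines of arithmetic. This is a sketch under the sizes stated in the comment above; staging_bytes is a hypothetical helper, and any padding or additional shared buffers the real kernel allocates are not shown in the thread.

```python
LDS_BYTES = 64 * 1024  # MI300 LDS size, as stated in the comment above
FP32_BYTES = 4


def staging_bytes(rows, cols, elem_bytes=FP32_BYTES):
    """Bytes needed to stage a rows x cols tile in shared memory."""
    return rows * cols * elem_bytes


full_tile = staging_bytes(128, 128)  # the whole tile staged at once
half_tile = staging_bytes(128, 64)   # one chunk of the split tile
num_chunks = 128 // 64               # chunks needed to cover the full tile

print(full_tile, half_tile, num_chunks, LDS_BYTES - full_tile)
```

The full FP32 tile alone consumes the entire 64KB, so anything else placed in LDS pushes the kernel over the limit, while the 128 * 64 chunk leaves half the budget free at the cost of two iterations.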
| } | ||
| // Reduce amax in the warp (32x32 tile) | ||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| warp_tile_amax = warp_reduce_max_64(amax); |
This is platform dependent. Can we just manually unroll the function inline here, and guard the last shuffle for 64 thread platforms?
Also, instead of a reduce-broadcast we can do an all reduce across threads using __shfl_xor?
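The all-reduce suggested here is the usual XOR butterfly. The Python simulation below is only an illustration of the data movement, not the kernel code: shfl_xor models a warp shuffle in which lane i reads the value held by lane i ^ lane_mask, and both function names are hypothetical.

```python
def shfl_xor(values, lane_mask):
    """Simulate a warp-wide XOR shuffle: lane i reads the value of lane i ^ lane_mask."""
    return [values[i ^ lane_mask] for i in range(len(values))]


def warp_allreduce_max(values):
    """Butterfly all-reduce: after log2(width) steps every lane holds the maximum."""
    width = len(values)  # must be a power of two, e.g. 32 or 64
    delta = width // 2
    while delta > 0:
        other = shfl_xor(values, delta)
        values = [max(a, b) for a, b in zip(values, other)]
        delta //= 2
    return values


lanes64 = [float((i * 37) % 64) for i in range(64)]  # a permutation of 0..63
lanes32 = lanes64[:32]
result64 = warp_allreduce_max(lanes64)
result32 = warp_allreduce_max(lanes32)
print(result64[0], result64[63], result32[0])
```

Because every lane ends up with the maximum, no separate broadcast step is needed, and supporting 64-thread waves only adds one extra step (delta = 32) in front of the 32-thread sequence.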
| #ifndef __HIP_PLATFORM_AMD__ | ||
| const bool full_tile = | ||
| row_length % BLOCK_TILE_DIM == 0 && num_rows % BLOCK_TILE_DIM == 0; | ||
|
|
unneeded whitespace change
| constexpr int kNumWarps = kThreadsPerBlock / kThreadsPerWarp; | ||
|
|
||
| // gfx942 (MI300) has 64KB LDS; the full 128x128 fp32 staging tile overflows it. | ||
| #if defined(__HIP_PLATFORM_AMD__) && !defined(__gfx950__) |
This should probably check specifically for gfx942, here and elsewhere, instead of !defined(__gfx950__).
Fixed.
#if defined(__HIP_PLATFORM_AMD__) && (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx942__))
The code block has to be compiled in both the host pass and the gfx942 device pass, and must NOT be compiled in the NVIDIA pass. The guard is true only for the AMD host pass or the gfx942 device pass.
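The behaviour of this guard across compilation passes can be tabulated with a small Python model. The booleans stand in for whether each macro is defined in a given pass; the pass names and the guard_enabled helper are illustrative assumptions, not part of the build.

```python
def guard_enabled(hip_platform_amd, hip_device_compile, gfx942):
    """Model of: defined(__HIP_PLATFORM_AMD__) &&
    (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx942__))"""
    return hip_platform_amd and (not hip_device_compile or gfx942)


# (hip_platform_amd, hip_device_compile, gfx942) for each assumed pass
PASSES = {
    "AMD host pass": (True, False, False),
    "AMD gfx942 device pass": (True, True, True),
    "AMD gfx950 device pass": (True, True, False),
    "NVIDIA pass": (False, False, False),
}

for name, macros in PASSES.items():
    print(name, guard_enabled(*macros))
```

The block is compiled for the AMD host pass and the gfx942 device pass only, which matches the intent described above.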
| const int c_s = warp_in_chunk * num_smem_reads; | ||
| size_t r_g = static_cast<size_t>(blockIdx.y) * kTileDim + r_s; | ||
| for (int chunk = 0; chunk < kNumChunks; ++chunk) { | ||
| __syncthreads(); |
We can probably skip the syncthreads for the first iteration, also a pragma unroll might help here.
|
|
||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| __device__ __forceinline__ float blockwise_warp_reduce_max(float val) { | ||
| __device__ __forceinline__ float warp_reduce_max_64(float val) { |
| // Step 2.3: Reduce amax | ||
| #pragma unroll | ||
| for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) { | ||
| const float other_amax = __shfl_down_sync(mask, amax, delta); |
I removed all *_sync calls from both the AMD-only path and the path shared by AMD and NVIDIA. I added a guard so that the AMD path uses the non-sync variants.
| using transformer_engine::detail::FP8BlockwiseRowwiseOption; | ||
|
|
||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| using WarpSyncMask = uint64_t; |
Review where it is used. Wavefront-level primitives on ROCm should not use a mask.
Removed the unnecessary mask definitions and switched the AMD path to non-sync primitives only. Reverted to upstream.
| if IS_HIP_EXTENSION: | ||
| return False, "FP8 block scaled gemm not yet supported for ROCm" | ||
| gpu_arch = get_device_compute_capability() | ||
| if gpu_arch >= (9, 0): |
dc4c5fd to 70c35df
| // const values configuration | ||
|
|
||
| #if defined(__HIP_PLATFORM_AMD__) && !defined(__gfx1250__) | ||
| constexpr size_t kThreadsPerWarp = 64; |
Is kThreadsPerWarp only used by device code and not by any dispatch functions?
| // Reduce amax in the warp (32x32 tile) | ||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| #pragma unroll | ||
| for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) { |
Can you please clarify this logic by using xor?
Description
Enable blockwise FP8 quantization on rocm
Changes
Please list the changes introduced in this PR:
Remove the HIP guard in quantization.py.
Guard the kernels that use TMA in quantization.
Add a branch to handle ROCm's different thread counts per wave.