enable blockwise FP8 quantization on rocm#609
Conversation
| # TODO replace with call to fp8.py when recipe added. | ||
| recipe_available = not IS_HIP_EXTENSION and (get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.8) | ||
| if IS_HIP_EXTENSION: | ||
| recipe_available = get_device_compute_capability() >= (9, 0) |
There was a problem hiding this comment.
Wouldn't this be always True on ROCm TE?
There was a problem hiding this comment.
This test targets MI300 and MI350 so I set to (9,0)
There was a problem hiding this comment.
I believe MI250 is (9,0) so this should be a > rather than a >=, or (9,4)
| @@ -1 +1 @@ | |||
| /************************************************************************* | |||
| #ifndef __HIP_PLATFORM_AMD__ | ||
| #include <cudaTypedefs.h> | ||
| #endif | ||
| #include <cuda_bf16.h> | ||
| #include <cuda_runtime.h> | ||
|
|
||
| #include <cfloat> | ||
| #ifndef __HIP_PLATFORM_AMD__ | ||
| #include <cuda/barrier> | ||
| #endif | ||
|
|
||
| #include "common/common.h" | ||
| #include "common/recipe/recipe_common.cuh" | ||
| #include "common/util/cuda_runtime.h" | ||
| #ifndef __HIP_PLATFORM_AMD__ | ||
| #include "common/util/ptx.cuh" | ||
| #endif |
There was a problem hiding this comment.
These #includes should be already disabled via hipify, so probably no need for the #ifndefs here.
| static constexpr float max = 448.0f; | ||
| static constexpr float max_inverse = 1.0 / max; |
There was a problem hiding this comment.
Is this change necessary? fp8e4m3 max depends on the device type on AMD.
There was a problem hiding this comment.
quantize_transpose_square_blockwise.cu and quantize_transpose_vector_blockwise.cu use
compute_scale_from_types<IType, fp8e4m3> for the first time, which exposed a latent bug in common.h
The #else branch of TypeExtrema<fp8e4m3> declared max as a static float,
This caused the constexpr static float max_finite_value initializer in TypeInfo in the same file to fail when the template was instantiated on the host.
The fix uses HIP_FP8_TYPE_FNUZ, used in hip_float8.h for selecting FNUZ at compile time, to make the host-pass branch constexpr as well.
There was a problem hiding this comment.
If the value really used host size, it should be runtime detected. If it is only for host translation of GPU code (i.e.. results are discarded), you can keep 448, no extra ifdefs is needed
There was a problem hiding this comment.
reverted to the original upstream. I instead changed the recipe_common.cuh following the other convention in quantize_transpose_vector_blockwise_fp4.cu L230
|
Could you give a description of what you want to achieve with this PR? My understanding is that block fp8 quantization relies on some upstream kernels that will need to be adapted for AMD. If you're just trying to enable the interface, I would argue that we should do this last, after we have a working quantization and GEMM path (and enabled and passing C++/Python tests). |
|
@alextmagro I tested with |
OK, in that case we need to add the cpp blockwise tests to the CMake file, and the pytorch test file to ci/pytorch.sh. |
…dant HIP guards, revert unnecessary common.h change
|
|
||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| using WarpSyncMask = uint64_t; | ||
| constexpr WarpSyncMask kFullWarpMask = 0xFFFFFFFFFFFFFFFFULL; |
There was a problem hiding this comment.
ROCm should not use it. See how *_sync calls are guarded in other places
There was a problem hiding this comment.
removed the mask and use ROCm __shfl instead of __shfl_sync
| } | ||
| } | ||
| // Reduce amax in the warp (32x32 tile) | ||
| #ifdef __HIP_PLATFORM_AMD__ |
There was a problem hiding this comment.
The whole this code is under #ifndef HIP_PLATFORM_AMD
| // const values configuration | ||
|
|
||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| constexpr size_t kThreadsPerWarp = 64; |
There was a problem hiding this comment.
It is platform dependent.
There was a problem hiding this comment.
fixed now guarded with gfx1250 for 32 threads
There was a problem hiding this comment.
Can we use warpSize from hipruntime here, since kThreadsPerWarp is only needed for device code?
There was a problem hiding this comment.
I think warpSize is not constexpr anymore. Or it is ?
There was a problem hiding this comment.
I assume you are referring to
inline __device__ const struct {
__device__ __attribute__((always_inline, const)) operator int() const noexcept {
return __builtin_amdgcn_wavefrontsize();
}
} warpSize{};
in amd_warp_functions.h
and this is not constexpr (assigned in the runtime) so cannot used.
| transpose/multi_cast_transpose.cu | ||
| transpose/quantize_transpose_vector_blockwise.cu #CUDA-only | ||
| transpose/quantize_transpose_vector_blockwise.cu | ||
| transpose/quantize_transpose_square_blockwise.cu |
There was a problem hiding this comment.
It should stay in transformer_engine_cuda_arch_specific_sources
| static constexpr float max = 448.0f; | ||
| static constexpr float max_inverse = 1.0 / max; |
There was a problem hiding this comment.
If the value really used host size, it should be runtime detected. If it is only for host translation of GPU code (i.e.. results are discarded), you can keep 448, no extra ifdefs is needed
|
MI300 has 64KB of LDS which makes overflow when loading 128 * 128 FP32 data into LDS. I created a helper and branched the kernel. When loading FP32 data, the kernel loads 128 * 64 chunk of data and iterate to quantize. From the host's view, the kernel quantizes 128 * 128 elements. |
| # TODO replace with call to fp8.py when recipe added. | ||
| recipe_available = not IS_HIP_EXTENSION and (get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.8) | ||
| if IS_HIP_EXTENSION: | ||
| recipe_available = get_device_compute_capability() >= (9, 0) |
There was a problem hiding this comment.
I believe MI250 is (9,0) so this should be a > rather than a >=, or (9,4)
| // const values configuration | ||
|
|
||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| constexpr size_t kThreadsPerWarp = 64; |
There was a problem hiding this comment.
Can we use warpSize from hipruntime here, since kThreadsPerWarp is only needed for device code?
| const int c_s = warp_in_chunk * num_smem_reads; | ||
| size_t r_g = static_cast<size_t>(blockIdx.y) * kTileDim + r_s; | ||
| for (int chunk = 0; chunk < kNumChunks; ++chunk) { | ||
| __syncthreads(); |
There was a problem hiding this comment.
We can probably skip the syncthreads for the first iteration, also a pragma unroll might help here.
There was a problem hiding this comment.
Looking more closely, could we remove this syncthreads completely, and then do a double buffer for load_chunk_to_smem?
| // const values configuration | ||
|
|
||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| constexpr size_t kThreadsPerWarp = 64; |
There was a problem hiding this comment.
I think warpSize is not constexpr anymore. Or it is ?
|
|
||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| __device__ __forceinline__ float blockwise_warp_reduce_max(float val) { | ||
| __device__ __forceinline__ float warp_reduce_max_64(float val) { |
| // Step 2.3: Reduce amax | ||
| #pragma unroll | ||
| for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) { | ||
| const float other_amax = __shfl_down_sync(mask, amax, delta); |
There was a problem hiding this comment.
I removed all *__sync from both AMD only path and AMD & Nvidia common path. I added guard and use non-sync in the AMD path.
| using transformer_engine::detail::FP8BlockwiseRowwiseOption; | ||
|
|
||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| using WarpSyncMask = uint64_t; |
There was a problem hiding this comment.
Review where it is used. Wavefront level primitives on ROcm should not use mask
There was a problem hiding this comment.
Removed the unnecessary mask definitions together with using only non-sync in the AMD path. Reverted to the upstream.
| if IS_HIP_EXTENSION: | ||
| return False, "FP8 block scaled gemm not yet supported for ROCm" | ||
| gpu_arch = get_device_compute_capability() | ||
| if gpu_arch >= (9, 0): |
dc4c5fd to
70c35df
Compare
| // const values configuration | ||
|
|
||
| #if defined(__HIP_PLATFORM_AMD__) && !defined(__gfx1250__) | ||
| constexpr size_t kThreadsPerWarp = 64; |
There was a problem hiding this comment.
is kThreadsPerWarp only used by device code and not any dispatch functions?
There was a problem hiding this comment.
It is used to compute NUM_THREADS_Y_IN_WARP in L71 for constexpr computation. And other than this, it is only used in the device code.
| // Reduce amax in the warp (32x32 tile) | ||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| #pragma unroll | ||
| for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) { |
There was a problem hiding this comment.
Can you please clarify this logic with using xor?
There was a problem hiding this comment.
The purpose is to share the amax value across the lanes in a wave.
shfl_xor(val, mask, width) computes: lane i gives val to lane i XOR mask and returns lane i XOR mask val.
So in the first iteration (delta = 64 / 2 = 32) lane 0 exchanges amax with lane 32 and lane 1 exchanges amax with lane 33 so on.
in the next iteration (delta = 16) lane 0 exchanges the amax (accumulated amax from the previous step) with lane 16 and lane 1 exchanges the amax with lane 17 and so on.
At the end of the loop, all lanes compute the amax across the wave. So no separate broadcast is needed.
There was a problem hiding this comment.
For reference: using xor for warp level all reduce is faster than shuffling down + broadcast since we skip the broadcast instructions, so is generally best practice.
Upstream has subwarp_reduce_max_broadcast and warp_reduce_max_broadcast using the slower reduce + broadcast implementation, so I will put in a PR to fix those up. We can then use those directly here?
Description
Please include a brief summary of the changes, relevant motivation and context.
Enable blockwise FP8 quantization on rocm
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
remove HIP guard in quantization.py
guard kernels using TMA in quantization.
add branch to handle rocm for different threads per wave
Checklist: