Skip to content

enable blockwise FP8 quantization on rocm#609

Open
asdfvg123 wants to merge 9 commits into
devfrom
yeonsoo/blockwise_fp8
Open

enable blockwise FP8 quantization on rocm#609
asdfvg123 wants to merge 9 commits into
devfrom
yeonsoo/blockwise_fp8

Conversation

@asdfvg123

Copy link
Copy Markdown

Description

Please include a brief summary of the changes, relevant motivation and context.

Enable blockwise FP8 quantization on rocm

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

remove HIP guard in quantization.py
guard kernels using TMA in quantization.
add branch to handle rocm for different threads per wave

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

# TODO replace with call to fp8.py when recipe added.
recipe_available = not IS_HIP_EXTENSION and (get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.8)
if IS_HIP_EXTENSION:
recipe_available = get_device_compute_capability() >= (9, 0)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't this be always True on ROCm TE?

@asdfvg123 asdfvg123 Jun 4, 2026

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test targets MI300 and MI350 so I set to (9,0)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe MI250 is (9,0) so this should be a > rather than a >=, or (9,4)

@@ -1 +1 @@
/*************************************************************************

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needs AMD copyright

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added

Comment on lines +8 to +24
#ifndef __HIP_PLATFORM_AMD__
#include <cudaTypedefs.h>
#endif
#include <cuda_bf16.h>
#include <cuda_runtime.h>

#include <cfloat>
#ifndef __HIP_PLATFORM_AMD__
#include <cuda/barrier>
#endif

#include "common/common.h"
#include "common/recipe/recipe_common.cuh"
#include "common/util/cuda_runtime.h"
#ifndef __HIP_PLATFORM_AMD__
#include "common/util/ptx.cuh"
#endif

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These #includes should be already disabled via hipify, so probably no need for the #ifndefs here.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

Comment thread transformer_engine/common/common.h Outdated
Comment on lines +639 to +640
static constexpr float max = 448.0f;
static constexpr float max_inverse = 1.0 / max;

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this change necessary? fp8e4m3 max depends on the device type on AMD.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

quantize_transpose_square_blockwise.cu and quantize_transpose_vector_blockwise.cu use
compute_scale_from_types<IType, fp8e4m3> for the first time, which exposed a latent bug in common.h

The #else branch of TypeExtrema<fp8e4m3> declared max as a static float,
This caused the constexpr static float max_finite_value initializer in TypeInfo in the same file to fail when the template was instantiated on the host.

The fix uses HIP_FP8_TYPE_FNUZ, used in hip_float8.h for selecting FNUZ at compile time, to make the host-pass branch constexpr as well.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the value really used host size, it should be runtime detected. If it is only for host translation of GPU code (i.e.. results are discarded), you can keep 448, no extra ifdefs is needed

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reverted to the original upstream. I instead changed the recipe_common.cuh following the other convention in quantize_transpose_vector_blockwise_fp4.cu L230

@alextmagro

alextmagro commented Jun 3, 2026

Copy link
Copy Markdown
Contributor

Could you give a description of what you want to achieve with this PR? My understanding is that block fp8 quantization relies on some upstream kernels that will need to be adapted for AMD.

If you're just trying to enable the interface, I would argue that we should do this last, after we have a working quantization and GEMM path (and enabled and passing C++/Python tests).

@asdfvg123

Copy link
Copy Markdown
Author

@alextmagro
This PR is to enable only the quantization in the AMD gpus, not the GEMM. There are two kernels in the upstream which uses TMA for the quantization and does not uses TMA for the quantization. I guarded the kernels which uses TMA and used the non-TMA kernels to quantize for AMD.

I tested with
tests/pytorch/test_float8blockwisetensor.py
and it passes [175 passed / 32 xpassed / 5 warnings]

@alextmagro

Copy link
Copy Markdown
Contributor

@alextmagro This PR is to enable only the quantization in the AMD gpus, not the GEMM. There are two kernels in the upstream which uses TMA for the quantization and does not uses TMA for the quantization. I guarded the kernels which uses TMA and used the non-TMA kernels to quantize for AMD.

I tested with tests/pytorch/test_float8blockwisetensor.py and it passes [175 passed / 32 xpassed / 5 warnings]

OK, in that case we need to add the cpp blockwise tests to the CMake file, and the pytorch test file to ci/pytorch.sh.

…dant HIP guards, revert unnecessary common.h change
Comment thread ci/pytorch.sh
Comment thread tests/cpp/operator/test_cast_float8blockwise.cu Outdated
Comment thread tests/cpp/operator/test_cast_float8blockwise.cu Outdated
Comment thread tests/pytorch/test_float8blockwisetensor.py Outdated
Comment thread transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu Outdated
Comment thread transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu Outdated
Comment thread transformer_engine/pytorch/quantization.py
@alextmagro

alextmagro commented Jun 4, 2026

Copy link
Copy Markdown
Contributor

By the way, to run CI you need to add a CI level label. L3 is required before merging, L1 is for lighter testing, mostly sGPU tests, if you are midway through the ticket and expect to make more changes

Uploading image.png…

@asdfvg123 asdfvg123 added the ci-level 1 CI test level 1 label Jun 4, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copyright

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added


#ifdef __HIP_PLATFORM_AMD__
using WarpSyncMask = uint64_t;
constexpr WarpSyncMask kFullWarpMask = 0xFFFFFFFFFFFFFFFFULL;

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ROCm should not use it. See how *_sync calls are guarded in other places

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed the mask and use ROCm __shfl instead of __shfl_sync

}
}
// Reduce amax in the warp (32x32 tile)
#ifdef __HIP_PLATFORM_AMD__

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The whole this code is under #ifndef HIP_PLATFORM_AMD

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed the dead branch

// const values configuration

#ifdef __HIP_PLATFORM_AMD__
constexpr size_t kThreadsPerWarp = 64;

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is platform dependent.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed now guarded with gfx1250 for 32 threads

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use warpSize from hipruntime here, since kThreadsPerWarp is only needed for device code?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think warpSize is not constexpr anymore. Or it is ?

@asdfvg123 asdfvg123 Jun 15, 2026

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume you are referring to

inline __device__ const struct {
  __device__ __attribute__((always_inline, const)) operator int() const noexcept {
    return __builtin_amdgcn_wavefrontsize();
  }
} warpSize{};

in amd_warp_functions.h
and this is not constexpr (assigned in the runtime) so cannot used.

transpose/multi_cast_transpose.cu
transpose/quantize_transpose_vector_blockwise.cu #CUDA-only
transpose/quantize_transpose_vector_blockwise.cu
transpose/quantize_transpose_square_blockwise.cu

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should stay in transformer_engine_cuda_arch_specific_sources

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

Comment thread transformer_engine/common/common.h Outdated
Comment on lines +639 to +640
static constexpr float max = 448.0f;
static constexpr float max_inverse = 1.0 / max;

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the value really used host size, it should be runtime detected. If it is only for host translation of GPU code (i.e.. results are discarded), you can keep 448, no extra ifdefs is needed

@asdfvg123

Copy link
Copy Markdown
Author

MI300 has 64KB of LDS which makes overflow when loading 128 * 128 FP32 data into LDS. I created a helper and branched the kernel. When loading FP32 data, the kernel loads 128 * 64 chunk of data and iterate to quantize. From the host's view, the kernel quantizes 128 * 128 elements.

# TODO replace with call to fp8.py when recipe added.
recipe_available = not IS_HIP_EXTENSION and (get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.8)
if IS_HIP_EXTENSION:
recipe_available = get_device_compute_capability() >= (9, 0)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe MI250 is (9,0) so this should be a > rather than a >=, or (9,4)

Comment thread transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu Outdated
Comment thread transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu Outdated
// const values configuration

#ifdef __HIP_PLATFORM_AMD__
constexpr size_t kThreadsPerWarp = 64;

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use warpSize from hipruntime here, since kThreadsPerWarp is only needed for device code?

const int c_s = warp_in_chunk * num_smem_reads;
size_t r_g = static_cast<size_t>(blockIdx.y) * kTileDim + r_s;
for (int chunk = 0; chunk < kNumChunks; ++chunk) {
__syncthreads();

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can probably skip the syncthreads for the first iteration, also a pragma unroll might help here.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed and added

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking more closely, could we remove this syncthreads completely, and then do a double buffer for load_chunk_to_smem?

Comment thread transformer_engine/pytorch/quantization.py
Comment thread transformer_engine/common/recipe/recipe_common.cuh
// const values configuration

#ifdef __HIP_PLATFORM_AMD__
constexpr size_t kThreadsPerWarp = 64;

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think warpSize is not constexpr anymore. Or it is ?


#ifdef __HIP_PLATFORM_AMD__
__device__ __forceinline__ float blockwise_warp_reduce_max(float val) {
__device__ __forceinline__ float warp_reduce_max_64(float val) {

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why 64?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is now removed.

// Step 2.3: Reduce amax
#pragma unroll
for (int delta = kNumThreadsStore / 2; delta > 0; delta /= 2) {
const float other_amax = __shfl_down_sync(mask, amax, delta);

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use __shfl_down on ROCm

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed all *__sync from both AMD only path and AMD & Nvidia common path. I added guard and use non-sync in the AMD path.

using transformer_engine::detail::FP8BlockwiseRowwiseOption;

#ifdef __HIP_PLATFORM_AMD__
using WarpSyncMask = uint64_t;

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Review where it is used. Wavefront level primitives on ROcm should not use mask

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed the unnecessary mask definitions together with using only non-sync in the AMD path. Reverted to the upstream.

if IS_HIP_EXTENSION:
return False, "FP8 block scaled gemm not yet supported for ROCm"
gpu_arch = get_device_compute_capability()
if gpu_arch >= (9, 0):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FP8 starts from 9.4

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

@asdfvg123 asdfvg123 force-pushed the yeonsoo/blockwise_fp8 branch from dc4c5fd to 70c35df Compare June 15, 2026 23:15
@asdfvg123 asdfvg123 requested review from alextmagro and ipanfilo June 16, 2026 00:00
// const values configuration

#if defined(__HIP_PLATFORM_AMD__) && !defined(__gfx1250__)
constexpr size_t kThreadsPerWarp = 64;

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is kThreadsPerWarp only used by device code and not any dispatch functions?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is used to compute NUM_THREADS_Y_IN_WARP in L71 for constexpr computation. And other than this, it is only used in the device code.

// Reduce amax in the warp (32x32 tile)
#ifdef __HIP_PLATFORM_AMD__
#pragma unroll
for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) {

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please clarify this logic with using xor?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The purpose is to share the amax value across the lanes in a wave.
shfl_xor(val, mask, width) computes: lane i gives val to lane i XOR mask and returns lane i XOR mask val.
So in the first iteration (delta = 64 / 2 = 32) lane 0 exchanges amax with lane 32 and lane 1 exchanges amax with lane 33 so on.
in the next iteration (delta = 16) lane 0 exchanges the amax (accumulated amax from the previous step) with lane 16 and lane 1 exchanges the amax with lane 17 and so on.
At the end of the loop, all lanes compute the amax across the wave. So no separate broadcast is needed.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For reference: using xor for warp level all reduce is faster than shuffling down + broadcast since we skip the broadcast instructions, so is generally best practice.

Upstream has subwarp_reduce_max_broadcast and warp_reduce_max_broadcast using the slower reduce + broadcast implementation, so I will put in a PR to fix those up. We can then use those directly here?

@asdfvg123 asdfvg123 requested a review from ipanfilo June 17, 2026 17:56
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ci-level 1 CI test level 1

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants