gfx1250 mxfp8 gemm: loosen restrictions on K #627
| // Check that K is compatible with the MXFP8 scale layout, and M/N are multiples of 16 | ||
| if (inputA->scaling_mode == NVTE_MXFP8_1D_SCALING || inputB->scaling_mode == NVTE_MXFP8_1D_SCALING) { | ||
| const bool is_gfx1250 = cuda::sm_arch() == 125; | ||
| const int required_k_multiple = is_gfx1250 ? 32 : 128; |
Add a TODO here to change this for gfx950 after scale preswizzle is in hipblasLt.
| } | ||
| if (params.k % 128) { | ||
| GTEST_SKIP() << "MXFP8 requires K to be a multiple of 128"; | ||
| const size_t required_k_multiple = (prop.major == 12 && prop.minor == 5) ? 32 : 128; |
It should be under HIP_PLATFORM_AMD below
What do you think of 3a7dd8f? Or do you prefer to move the whole if (use_mxfp8) part into the #ifdef __HIP_PLATFORM_AMD__ below?
It looks like even the previous constraints were actually ROCm-specific, despite the idea that ROCm-specific constraints are added below under the ifdef.
Moreover, some ROCm constraints are more relaxed than the naive ones (multiple of 16 vs. multiple of 32), so we cannot have naive generic constraints here.
That said, and bearing in mind that this is an AMD-originated test, it might not be worth the effort to separate ROCm and generic constraints, so you may revert the guarding or keep it as is. Sorry for the confusion.
| const bool use_mxfp8 = params.scaling_mode == NVTEScalingMode::NVTE_MXFP8_1D_SCALING; | ||
| cudaDeviceProp prop; | ||
| (void)cudaGetDeviceProperties(&prop, 0); |
Can we use the get_arch function here to avoid calling this for every test?
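One way to avoid repeating the device query in every test is to cache its result in a function-local static. The sketch below is illustrative only: `query_device_arch`, `query_count`, and `cached_device_arch` are hypothetical names, the stub stands in for the `cudaGetDeviceProperties` call in the diff, and the signature of the existing `get_arch` helper is not shown in this thread, so it is not used here.

```cpp
#include <cassert>
#include <utility>

// Hypothetical call counter, only here to make the caching observable.
inline int& query_count() {
  static int n = 0;
  return n;
}

// Stand-in for the expensive device query (cudaGetDeviceProperties in the diff).
// Returns {major, minor}; the values are placeholders, not read from a device.
inline std::pair<int, int> query_device_arch() {
  ++query_count();
  return {12, 5};
}

// Function-local static: initialized on first use only, and that
// initialization is thread-safe since C++11.
inline const std::pair<int, int>& cached_device_arch() {
  static const std::pair<int, int> arch = query_device_arch();
  return arch;
}
```

With this shape, each test reads `cached_device_arch()` and the underlying query runs once per process rather than once per test case.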
| GTEST_SKIP() << "MXFP8 requires K to be a multiple of 128"; | ||
| size_t required_k_multiple = 128; | ||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| required_k_multiple = (prop.major == 12 && prop.minor == 5) ? 32 : 128; |
nit: gfx1250 requires a multiple of the block size, not necessarily 32. I believe 16 may also be supported.
| NVTE_CHECK((k % required_k_multiple) == 0, | ||
| "GEMM K dimension must be multiple of ", required_k_multiple, | ||
| " for MXFP8 scaling (got K=", k, ")"); | ||
| NVTE_CHECK((m % 16) == 0, "GEMM M dimension must be multiple of 16 for MXFP8 scaling (got M=", m, ")"); |
I think that hipblaslt supports arbitrary M/N for gfx1250?
Description
Loosen the restriction on K for MXFP8 GEMM on gfx1250: K must now be a multiple of 32 there (128 remains the requirement elsewhere). Confirmed with the hipBLASLt developers.
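As a rough illustration of the rule described here, the check can be written as a pair of small helpers. This is a minimal sketch, not the code in the PR: `required_k_multiple` and `k_is_supported` are hypothetical free functions, and the architecture test is reduced to a boolean so the example has no CUDA/HIP dependency. The values 32 and 128 are taken from the diff.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper: the multiple that K must satisfy for MXFP8 GEMM.
// 32 applies to gfx1250; 128 is the requirement on other architectures.
constexpr std::size_t required_k_multiple(bool is_gfx1250) {
  return is_gfx1250 ? 32 : 128;
}

// Hypothetical helper: true when K meets the architecture's requirement.
constexpr bool k_is_supported(std::size_t k, bool is_gfx1250) {
  return k % required_k_multiple(is_gfx1250) == 0;
}
```

For example, K=96 passes on gfx1250 but is still rejected on other architectures, which is the kind of shape this PR newly allows.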