Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions tests/cpp/operator/test_cast_mxfp8_grouped.cu
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ void performTest(const ProcessingMethod processing_method,

NVTEShape logical_shape_ = nvte_make_shape(logical_shape_vec.data(), logical_shape_vec.size());

std::vector<size_t> dbias_logical_shape_vec= {num_tensors, cols};
std::vector<size_t> dbias_logical_shape_vec = {num_tensors, cols};
NVTEShape dbias_logical_shape_ = nvte_make_shape(dbias_logical_shape_vec.data(),
dbias_logical_shape_vec.size());

Expand Down Expand Up @@ -554,6 +554,11 @@ void performTest(const ProcessingMethod processing_method,
const double abs_tolerable_mismatches_limit = 0.0;
const double rel_tolerable_mismatches_limit = 0.0;

// Compare only the allocated, contiguous output range.
// In graph-safe mode, the logical shape may include trailing garbage beyond offsets_h.back().
const size_t compare_rows = 1;
const size_t compare_cols = elts_num;

if (rowwise) {
cudaMemcpy(out_data_rowwise_h.data(), out_data_rowwise_d, out_data_size, cudaMemcpyDeviceToHost);
cudaMemcpy(out_scales_rowwise_h.data(), out_scales_rowwise_d, rowwise_scales_size, cudaMemcpyDeviceToHost);
Expand All @@ -566,7 +571,8 @@ void performTest(const ProcessingMethod processing_method,
const size_t mismatches_elts = 32 * mismatches_scales;

compare_scaled_elts<OutputType>("rowwise_output", out_data_rowwise_ref.data(),
out_data_rowwise_h.data(), rows, cols, true, mismatches_elts);
out_data_rowwise_h.data(), compare_rows, compare_cols,
true, mismatches_elts);
}

if (colwise) {
Expand All @@ -581,7 +587,8 @@ void performTest(const ProcessingMethod processing_method,
const size_t mismatches_elts = 32 * mismatches_scales;

compare_scaled_elts<OutputType>("colwise_output", out_data_colwise_ref.data(),
out_data_colwise_h.data(), rows, cols, false, mismatches_elts);
out_data_colwise_h.data(), compare_rows, compare_cols,
false, mismatches_elts);
}

if (compute_dbias) {
Expand Down Expand Up @@ -651,6 +658,7 @@ std::vector<std::vector<size_t>> input_config = {
{VARYING_FIRST_DIM, 3, 1024,144, 128,384,512},
{VARYING_FIRST_DIM, 4, 1536,160, 128,384,512,512},
{VARYING_FIRST_DIM, 5, 4096,512, 128,256,384,1024,2304},
{VARYING_FIRST_DIM, 5, 16 * 4096,512, 128,256,384,1024,2304},
{VARYING_LAST_DIM, 3, 256,896, 128,256,512},
{VARYING_BOTH_DIMS, 2, 1,(128*128)+(256*256), 128,256, 128,256},
{VARYING_BOTH_DIMS, 2, 1,(256*128)+(512*640), 256,512, 128,640},
Expand Down
9 changes: 5 additions & 4 deletions transformer_engine/common/cast/core/common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,15 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
const size_t tensor_id = blockIdx.y;
const size_t tensor_rows = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS)
? (first_logical_dim / num_tensors)
: first_dims_ptr[tensor_id];
: static_cast<size_t>(first_dims_ptr[tensor_id]);

const size_t rows = tensor_rows / chunk_dim_Y;
const size_t cols = last_logical_dim;

const size_t dbias_in_offset_Y = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS)
? (tensor_id * (tensor_rows / chunk_dim_Y))
: (offsets_ptr[tensor_id] / cols / chunk_dim_Y);
const size_t dbias_in_offset_Y =
(shape_rep == ShapeRepresentation::SAME_BOTH_DIMS)
? (tensor_id * (tensor_rows / chunk_dim_Y))
: (static_cast<size_t>(offsets_ptr[tensor_id]) / cols / chunk_dim_Y);

const size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;

Expand Down
Loading
Loading