Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 69 additions & 45 deletions cub/cub/device/dispatch/tuning/tuning_scan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -873,63 +873,87 @@ struct policy_selector
// TODO(griwes): remove this field before policy_selector is publicly exposed
bool benchmark_match;

// Builds the default (not benchmark-tuned) warpspeed scan policy. get_warpspeed_policy() returns it for sm_100 when no
// tuned policy matches, and get_sm120_fallback_warpspeed_policy() uses it as its starting point.
//
// Returns a scan_warpspeed_policy with num_reduce_and_scan_warps, look_ahead_items_per_thread and items_per_thread
// set; any other field keeps its value-initialized default.
_CCCL_API constexpr auto get_sm100_fallback_warpspeed_policy() const -> scan_warpspeed_policy
{
  scan_warpspeed_policy warpspeed_policy{};

  // TODO(bgruber): tune this
#if _CCCL_COMPILER(NVHPC)
  // need to reduce the number of threads to <= 256, so each thread can use up to 255 registers. This avoids an
  // error in ptxas, see also: https://github.com/NVIDIA/cccl/issues/7700.
  warpspeed_policy.num_reduce_and_scan_warps = 2;
#else // _CCCL_COMPILER(NVHPC)
  warpspeed_policy.num_reduce_and_scan_warps = 4;
#endif // _CCCL_COMPILER(NVHPC)

  // TODO(bgruber): 5 is a bit better for complex<float>
  warpspeed_policy.look_ahead_items_per_thread = accum_size == 2 ? 3 : 4;

  // manual tuning based on cub.bench.scan.exclusive.sum.base
  // 256 / sizeof(InputValueT) - 1 should minimize bank conflicts (and fits into 48KiB SMEM)
  // 2-byte types and double needed special handling
  // The max(..., 1) clamp keeps items_per_thread at least 1 when the divisor exceeds 128.
  warpspeed_policy.items_per_thread = ::cuda::std::max(256 / (input_value_size == 2 ? 2 : accum_size) - 1, 1);
  // TODO(bgruber): the special handling of double below is a LOT faster on B200, but exceeds 48KiB SMEM
  // clang-format off
  // | F64 | I32 | 72576 | 11.295 us | 2.44% | 11.917 us | 8.02% | 0.622 us | 5.50% | SLOW |
  // | F64 | I32 | 1056384 | 16.162 us | 6.24% | 15.847 us | 5.57% | -0.315 us | -1.95% | SAME |
  // | F64 | I32 | 16781184 | 65.696 us | 1.64% | 60.650 us | 3.37% | -5.046 us | -7.68% | FAST |
  // | F64 | I32 | 268442496 | 863.896 us | 0.22% | 679.100 us | 0.93% | -184.796 us | -21.39% | FAST |
  // | F64 | I32 | 1073745792 | 3.418 ms | 0.12% | 2.662 ms | 0.46% | -755.740 us | -22.11% | FAST |
  // | F64 | I64 | 72576 | 12.301 us | 8.18% | 12.987 us | 5.75% | 0.686 us | 5.58% | SAME |
  // | F64 | I64 | 1056384 | 16.775 us | 5.70% | 16.091 us | 6.14% | -0.684 us | -4.08% | SAME |
  // | F64 | I64 | 16781184 | 66.970 us | 1.41% | 58.024 us | 3.17% | -8.946 us | -13.36% | FAST |
  // | F64 | I64 | 268442496 | 863.826 us | 0.23% | 676.465 us | 0.98% | -187.360 us | -21.69% | FAST |
  // | F64 | I64 | 1073745792 | 3.419 ms | 0.11% | 2.664 ms | 0.48% | -755.409 us | -22.09% | FAST |
  // | F64 | I64 | 4294975104 | 13.641 ms | 0.05% | 10.575 ms | 0.24% | -3065.815 us | -22.48% | FAST |
  // clang-format on
  // (256 / (sizeof(InputValueT) == 2 ? 2 : (::cuda::std::is_same_v<InputValueT, double> ? 4 : sizeof(AccumT))) -
  // 1);

  return warpspeed_policy;
}

// Builds the warpspeed scan policy that get_warpspeed_policy() returns for sm_120 and newer.
//
// Starts from get_sm100_fallback_warpspeed_policy() and, only when the operation is op_kind_t::other and the input
// type is arithmetic, overrides items_per_thread:
//   - 4- and 8-byte inputs: fixed at 127
//   - all other sizes: the sm_100 value capped at 63 (inputs of <= 2 bytes) or 127 (everything else)
// For every other operation/input combination the sm_100 fallback is returned unchanged.
_CCCL_API constexpr auto get_sm120_fallback_warpspeed_policy() const -> scan_warpspeed_policy
{
  auto policy = get_sm100_fallback_warpspeed_policy();
  if (operation_t == op_kind_t::other && is_arithmetic_type(input_type))
  {
    if (input_value_size == 4 || input_value_size == 8)
    {
      policy.items_per_thread = 127;
    }
    else
    {
      // min() only ever lowers the inherited value; it never raises it above the sm_100 fallback.
      policy.items_per_thread = ::cuda::std::min(policy.items_per_thread, input_value_size <= 2 ? 63 : 127);
    }
  }
  return policy;
}

// Selects the warpspeed scan policy for the given architecture.
//
// arch: architecture to select a policy for.
// Returns:
//   - arch >= sm_120: the result of get_sm120_fallback_warpspeed_policy()
//   - arch >= sm_100: a benchmark-tuned policy when the operation is op_kind_t::plus, the accumulator is primitive or
//     trivially copy constructible, and the input value size is 1 or 2 bytes; otherwise the result of
//     get_sm100_fallback_warpspeed_policy()
//   - older architectures: an empty optional
_CCCL_API constexpr auto get_warpspeed_policy(::cuda::arch_id arch) const
  -> ::cuda::std::optional<scan_warpspeed_policy>
{
  // Checked before sm_100 because sm_120 also satisfies `arch >= sm_100`.
  if (arch >= ::cuda::arch_id::sm_120)
  {
    return get_sm120_fallback_warpspeed_policy();
  }
  if (arch >= ::cuda::arch_id::sm_100)
  {
    // tunings from cub/benchmarks/bench/scan/exclusive/sum.warpspeed.cu
    if (operation_t == op_kind_t::plus && accum_is_primitive_or_trivially_copy_constructible)
    {
      switch (input_value_size)
      {
        case 1:
          // wrps_4.lbi_8.ipt_160 () 1.264254 1.264254 1.264254 1.264254
          return scan_warpspeed_policy{4, 8, 160 - 1};
        case 2:
          // wrps_6.lbi_2.ipt_96 () 1.167633 1.167633 1.167633 1.167633
          return scan_warpspeed_policy{6, 2, 96 - 1};

        // TODO(bgruber): tune for more data types
        default:
          break;
      }
    }

    // No tuned entry matched: use the untuned sm_100 defaults.
    return get_sm100_fallback_warpspeed_policy();
  }

  return {};
}

Expand Down
Loading