From f86be17a635c3716bafadbeb904eeb4bed80033d Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Fri, 19 Dec 2025 10:05:12 -0600 Subject: [PATCH 1/2] Add early return for element tile calculation combine logic and ensure we're also exiting for cases where dims are less than 1 --- src/targets/gpu/compile_gen.cpp | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/src/targets/gpu/compile_gen.cpp b/src/targets/gpu/compile_gen.cpp index ced8ab03bae..74ea67f97de 100644 --- a/src/targets/gpu/compile_gen.cpp +++ b/src/targets/gpu/compile_gen.cpp @@ -230,12 +230,18 @@ tile tile::elements(const std::vector& inputs, std::size_t noutputs) if(nargs != 1) return {}; - const auto& s = inputs.front(); - auto dim1 = compute_tile_factor(s.lens()[result.axis]); - auto dim2 = compute_tile_factor(s.lens().back(), 4096 / dim1); - if(dim1 == 1 or dim2 == 1) + const auto& s = inputs.front(); + auto dim1 = compute_tile_factor(s.lens()[result.axis]); + auto dim2 = compute_tile_factor(s.lens().back(), 4096 / dim1); + auto tile_size = dim1 * dim2; + // equivalent to dim1 * (dim2 + 1) to avoid bank conflicts + auto tile_bytes = (tile_size + dim1) * s.type_size(); + + if(dim1 <= 1 or dim2 <= 1 or tile_bytes > 65536) return {}; + result.ntiles = s.elements() / tile_size; + result.inner = s.lens(); std::fill(result.inner.begin(), result.inner.end(), 1); result.inner[result.axis] = dim1; @@ -245,13 +251,6 @@ tile tile::elements(const std::vector& inputs, std::size_t noutputs) result.outer[result.axis] /= dim1; result.outer.back() /= dim2; - auto tile_size = dim1 * dim2; - result.ntiles = s.elements() / tile_size; - // equivalent to dim1 * (dim2 + 1) to avoid bank conflicts - auto tile_bytes = (tile_size + dim1) * s.type_size(); - if(tile_bytes > 65536) - return {}; - result.block_size = std::min(256, integer_divide_ceil(tile_size / 4, 64) * 64); return result; } From dc86a84b3d6e03420d7831cd22ac7a883d29c8f1 Mon Sep 17 00:00:00 2001 From: Ted Themistokleous Date: Mon, 22 Dec 2025 20:15:01 +0000 Subject: [PATCH 2/2] format --- src/targets/gpu/compile_gen.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/targets/gpu/compile_gen.cpp b/src/targets/gpu/compile_gen.cpp index 74ea67f97de..2215b86a604 100644 --- a/src/targets/gpu/compile_gen.cpp +++ b/src/targets/gpu/compile_gen.cpp @@ -230,17 +230,17 @@ tile tile::elements(const std::vector& inputs, std::size_t noutputs) if(nargs != 1) return {}; - const auto& s = inputs.front(); - auto dim1 = compute_tile_factor(s.lens()[result.axis]); - auto dim2 = compute_tile_factor(s.lens().back(), 4096 / dim1); - auto tile_size = dim1 * dim2; + const auto& s = inputs.front(); + auto dim1 = compute_tile_factor(s.lens()[result.axis]); + auto dim2 = compute_tile_factor(s.lens().back(), 4096 / dim1); + auto tile_size = dim1 * dim2; // equivalent to dim1 * (dim2 + 1) to avoid bank conflicts auto tile_bytes = (tile_size + dim1) * s.type_size(); if(dim1 <= 1 or dim2 <= 1 or tile_bytes > 65536) return {}; - result.ntiles = s.elements() / tile_size; + result.ntiles = s.elements() / tile_size; result.inner = s.lens(); std::fill(result.inner.begin(), result.inner.end(), 1);