diff --git a/src/targets/gpu/compile_gen.cpp b/src/targets/gpu/compile_gen.cpp index ced8ab03bae..2215b86a604 100644 --- a/src/targets/gpu/compile_gen.cpp +++ b/src/targets/gpu/compile_gen.cpp @@ -230,12 +230,18 @@ tile tile::elements(const std::vector& inputs, std::size_t noutputs) if(nargs != 1) return {}; - const auto& s = inputs.front(); - auto dim1 = compute_tile_factor(s.lens()[result.axis]); - auto dim2 = compute_tile_factor(s.lens().back(), 4096 / dim1); - if(dim1 == 1 or dim2 == 1) + const auto& s = inputs.front(); + auto dim1 = compute_tile_factor(s.lens()[result.axis]); + auto dim2 = compute_tile_factor(s.lens().back(), 4096 / dim1); + auto tile_size = dim1 * dim2; + // equivalent to dim1 * (dim2 + 1) to avoid bank conflicts + auto tile_bytes = (tile_size + dim1) * s.type_size(); + + if(dim1 <= 1 or dim2 <= 1 or tile_bytes > 65536) return {}; + result.ntiles = s.elements() / tile_size; + result.inner = s.lens(); std::fill(result.inner.begin(), result.inner.end(), 1); result.inner[result.axis] = dim1; @@ -245,13 +251,6 @@ tile tile::elements(const std::vector& inputs, std::size_t noutputs) result.outer[result.axis] /= dim1; result.outer.back() /= dim2; - auto tile_size = dim1 * dim2; - result.ntiles = s.elements() / tile_size; - // equivalent to dim1 * (dim2 + 1) to avoid bank conflicts - auto tile_bytes = (tile_size + dim1) * s.type_size(); - if(tile_bytes > 65536) - return {}; - result.block_size = std::min(256, integer_divide_ceil(tile_size / 4, 64) * 64); return result; }