diff --git a/backends/cadence/aot/replace_ops.py b/backends/cadence/aot/replace_ops.py index 4b60feb2121..67a2aaae4fc 100644 --- a/backends/cadence/aot/replace_ops.py +++ b/backends/cadence/aot/replace_ops.py @@ -2021,6 +2021,13 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: if any(d != 1 for d in dilation): return False + # When channel_last=True (NHWC layout), im2row rearranges data from + # kp-major (NHWC natural order) to channel-major output layout. + # A simple view_copy cannot perform this data rearrangement. + channel_last = node.args[6] if len(node.args) > 6 else False + if channel_last: + return False + # im2row works on 3D or 4D tensors. # Output shape[1:-1] will be unit if input spatial dimensions are the same as kernel spatial dimensions. output_shape = node.meta["val"].shape diff --git a/backends/cadence/aot/tests/test_ref_implementations.py b/backends/cadence/aot/tests/test_ref_implementations.py index 29d74258ed4..f089e36d4d5 100644 --- a/backends/cadence/aot/tests/test_ref_implementations.py +++ b/backends/cadence/aot/tests/test_ref_implementations.py @@ -2270,6 +2270,44 @@ def test_avg_pool2d( dtype=torch.float32, ), ), + # Multi-channel input, 2x2 kernel, stride 1, no padding, NHWC. + # Same channel values as nchw_multi_channel above, just laid out + # in NHWC order. Expected output is byte-for-byte identical to + # the NCHW case — this asserts that NHWC im2row produces the + # channel-major [c][kp] layout (matching torch.nn.functional.unfold + # after NHWC->NCHW conversion). A [kp][c] layout (the prior bug) + # would instead produce [1, 10, 2, 11, 4, 13, 5, 14, ...]. 
+ ( + "nhwc_multi_channel", + torch.tensor( + [ + [ + [[1, 10], [2, 11], [3, 12]], + [[4, 13], [5, 14], [6, 15]], + [[7, 16], [8, 17], [9, 18]], + ] + ], + dtype=torch.float32, + ), # (N=1, H=3, W=3, C=2) + (2, 2), + (1, 1), + (0, 0), + (1, 1), + None, + True, # channel_last + False, + torch.tensor( + [ + [ + [1, 2, 4, 5, 10, 11, 13, 14], + [2, 3, 5, 6, 11, 12, 14, 15], + [4, 5, 7, 8, 13, 14, 16, 17], + [5, 6, 8, 9, 14, 15, 17, 18], + ], + ], + dtype=torch.float32, + ), + ), # Multi-channel input and multi-channel zero-point ( "nchw_multi_channel_and_zero_point_no_padding", diff --git a/backends/cadence/aot/tests/test_replace_ops_passes.py b/backends/cadence/aot/tests/test_replace_ops_passes.py index 170da6deb09..6b1c2a2396c 100644 --- a/backends/cadence/aot/tests/test_replace_ops_passes.py +++ b/backends/cadence/aot/tests/test_replace_ops_passes.py @@ -2103,6 +2103,42 @@ def test_replace_linear_like_conv(self) -> None: count_node(gm_after_replacement, exir_ops.edge.aten.view_copy.default), 1 ) + def test_no_replace_for_channel_last(self) -> None: + # NHWC im2row rearranges data from kp-major (NHWC natural order) to + # channel-major output layout — it is not a no-op view. The pass must + # not elide it to view_copy even when shape conditions would otherwise + # allow replacement (kernel == input spatial dims). 
+ in_h, in_w = 13, 15 + x = torch.randn(1, in_h, in_w, 3) # NHWC + pad_value = torch.tensor(0, dtype=torch.int32) + channels_last = True + gm = single_op_builder( + placeholders=(x, pad_value), + op=exir_ops.edge.cadence.im2row.default, + args=(x, (in_h, in_w), (1, 1), (0, 0), (1, 1), pad_value, channels_last), + ) + gm = ExportPass().call(gm).graph_module + self.assertEqual(count_node(gm, exir_ops.edge.cadence.im2row.default), 1) + self.assertEqual(count_node(gm, exir_ops.edge.aten.view_copy.default), 0) + + gm_before = copy.deepcopy(gm) + + p = ReplaceIm2RowWithViewPass() + result = p.call(gm) + self.assertFalse(result.modified) + gm_after_replacement = result.graph_module + + inputs = [x, pad_value] + validate(gm_before, gm_after_replacement, inputs, "ReplaceIm2RowWithViewPass") + + # No replacement: im2row remains, no new view_copy. + self.assertEqual( + count_node(gm_after_replacement, exir_ops.edge.cadence.im2row.default), 1 + ) + self.assertEqual( + count_node(gm_after_replacement, exir_ops.edge.aten.view_copy.default), 0 + ) + class TestReplaceConvWithChannelLastConvPass(unittest.TestCase): def create_conv1d_graphmodule( diff --git a/backends/cadence/hifi/operators/op_im2row_out.cpp b/backends/cadence/hifi/operators/op_im2row_out.cpp index 0ff977c471c..413835a0abc 100644 --- a/backends/cadence/hifi/operators/op_im2row_out.cpp +++ b/backends/cadence/hifi/operators/op_im2row_out.cpp @@ -61,34 +61,32 @@ __attribute__((always_inline)) void im2row_( // array of size (out_height * out_width) x channels_col const int32_t channels_col = channels * kernel_h * kernel_w; - // If the layout is NHWC, we can copy 'channels' worth of contiguous data - // points when performing im2row. + // If the layout is NHWC, the input data is contiguous per-pixel (H, W, C). + // The output layout must match torch.nn.functional.unfold, which is [c][kp]: + // output[c * num_kp + kp] for each output position. 
if (channels_last) { + const int32_t num_kp = kernel_h * kernel_w; // Iterate over the output domain for (int _h = 0; _h < out_height; ++_h) { for (int _w = 0; _w < out_width; ++_w) { int32_t i_col = _h * out_width + _w; - // Each point in the output domain is the result of applying a filter of - // size kernel_h x kernel_w x channels on the input. But since channels - // is contiguous, we will not explicitly have a loop for it. for (int _kh = 0; _kh < kernel_h; ++_kh) { int32_t h_im = _h * stride_h - pad_h + _kh * dilation_h; for (int _kw = 0; _kw < kernel_w; ++_kw) { int32_t w_im = _w * stride_w - pad_w + _kw * dilation_w; + int32_t kp = _kh * kernel_w + _kw; - // h_im and w_im are the actual height and width coordinates of the - // input tensor from where we need to copy 'channels' points. - const T* __restrict__ slice_im = - data_im + (h_im * width + w_im) * channels; - T* __restrict__ slice_col = data_col + i_col * channels_col + - (_kh * kernel_w + _kw) * channels; - // If the coordinates were within the input domain, we copy - // 'channels' contiguous values. Otherwise we will fill the output - // with 0's. 
if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { - std::memcpy(slice_col, slice_im, channels * sizeof(T)); + const T* __restrict__ pixel = + data_im + (h_im * width + w_im) * channels; + for (int _c = 0; _c < channels; ++_c) { + data_col[i_col * channels_col + _c * num_kp + kp] = pixel[_c]; + } } else { - std::fill_n(slice_col, channels, T(in_zero_point)); + for (int _c = 0; _c < channels; ++_c) { + data_col[i_col * channels_col + _c * num_kp + kp] = + static_cast<T>(in_zero_point); + } } } } diff --git a/backends/cadence/hifi/third-party/nnlib/xa_nn_im2row.c b/backends/cadence/hifi/third-party/nnlib/xa_nn_im2row.c index 7008ee58f0a..8d3a3a1c506 100644 --- a/backends/cadence/hifi/third-party/nnlib/xa_nn_im2row.c +++ b/backends/cadence/hifi/third-party/nnlib/xa_nn_im2row.c @@ -18,53 +18,31 @@ WORD32 xa_nn_im2row_quantized( WORD8 *__restrict__ data_col, WORD32 channels_last) { const WORD32 channels_col = channels * kernel_h * kernel_w; - // If the layout is NHWC, we can copy 'channels' worth of contiguous data - // points when performing im2row. + // If the layout is NHWC, the input data is contiguous per-pixel (H, W, C). + // The output layout must match torch.nn.functional.unfold, which is [c][kp]: + // output[c * num_kp + kp] for each output position. if (channels_last) { + const int32_t num_kp = kernel_h * kernel_w; // Iterate over the output domain for (int _h = 0; _h < out_height; ++_h) { for (int _w = 0; _w < out_width; ++_w) { int32_t i_col = _h * out_width + _w; - // Each point in the output domain is the result of applying a filter of - // size kernel_h x kernel_w x channels on the input. But since channels - // is contiguous, we will not explicitly have a loop for it. 
for (int _kh = 0; _kh < kernel_h; ++_kh) { int32_t h_im = _h * stride_h - pad_h + _kh * dilation_h; for (int _kw = 0; _kw < kernel_w; ++_kw) { int32_t w_im = _w * stride_w - pad_w + _kw * dilation_w; + int32_t kp = _kh * kernel_w + _kw; - // h_im and w_im are the actual height and width coordinates of the - // input tensor from where we need to copy 'channels' points. - const int8_t *__restrict__ slice_im = - data_im + (h_im * width + w_im) * channels; - int8_t *__restrict__ slice_col = data_col + i_col * channels_col + - (_kh * kernel_w + _kw) * channels; - // If the coordinates were within the input domain, we copy - // 'channels' contiguous values. Otherwise we will fill the output - // with 0's. if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { - const ae_int32x2 *pae_inp = (const ae_int32x2 *)slice_im; - ae_int32x2 *pae_out = (ae_int32x2 *)slice_col; - ae_valign inp_a, out_a; - inp_a = AE_LA64_PP(pae_inp); - out_a = AE_ZALIGN64(); - - ae_int32x2 d0; - for (int ic = 0; ic < channels >> 3; ic++) { - AE_LA32X2_IP(d0, inp_a, pae_inp); - AE_SA32X2_IP(d0, out_a, pae_out); - } - AE_SA64POS_FP(out_a, pae_out); - - int remainder = channels & 7; - int8_t *ptmp_in = (int8_t *)pae_inp; - int8_t *ptmp_out = (int8_t *)pae_out; - for (int ic = 0; ic < remainder; ic++) { - *ptmp_out++ = *ptmp_in++; + const int8_t *__restrict__ pixel = + data_im + (h_im * width + w_im) * channels; + for (int _c = 0; _c < channels; ++_c) { + data_col[i_col * channels_col + _c * num_kp + kp] = pixel[_c]; } } else { - for (int i = 0; i < channels; i++) { - slice_col[i] = (int8_t)(in_zero_point); + for (int _c = 0; _c < channels; ++_c) { + data_col[i_col * channels_col + _c * num_kp + kp] = + (int8_t)(in_zero_point); } } }