Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions backends/cadence/aot/replace_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2021,6 +2021,13 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
if any(d != 1 for d in dilation):
return False

# When channel_last=True (NHWC layout), im2row rearranges data from
# kp-major (NHWC natural order) to channel-major output layout.
# A simple view_copy cannot perform this data rearrangement.
channel_last = node.args[6] if len(node.args) > 6 else False
if channel_last:
return False

# im2row works on 3D or 4D tensors.
# Output shape[1:-1] will be unit if input spatial dimensions are the same as kernel spatial dimensions.
output_shape = node.meta["val"].shape
Expand Down
38 changes: 38 additions & 0 deletions backends/cadence/aot/tests/test_ref_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2270,6 +2270,44 @@ def test_avg_pool2d(
dtype=torch.float32,
),
),
# Multi-channel input, 2x2 kernel, stride 1, no padding, NHWC.
# Same channel values as nchw_multi_channel above, just laid out
# in NHWC order. Expected output is byte-for-byte identical to
# the NCHW case — this asserts that NHWC im2row produces the
# channel-major [c][kp] layout (matching torch.nn.functional.unfold
# after NHWC->NCHW conversion). A [kp][c] layout (the prior bug)
# would instead produce [1, 10, 2, 11, 4, 13, 5, 14, ...].
(
"nhwc_multi_channel",
torch.tensor(
[
[
[[1, 10], [2, 11], [3, 12]],
[[4, 13], [5, 14], [6, 15]],
[[7, 16], [8, 17], [9, 18]],
]
],
dtype=torch.float32,
), # (N=1, H=3, W=3, C=2)
(2, 2),
(1, 1),
(0, 0),
(1, 1),
None,
True, # channel_last
False,
torch.tensor(
[
[
[1, 2, 4, 5, 10, 11, 13, 14],
[2, 3, 5, 6, 11, 12, 14, 15],
[4, 5, 7, 8, 13, 14, 16, 17],
[5, 6, 8, 9, 14, 15, 17, 18],
],
],
dtype=torch.float32,
),
),
# Multi-channel input and multi-channel zero-point
(
"nchw_multi_channel_and_zero_point_no_padding",
Expand Down
36 changes: 36 additions & 0 deletions backends/cadence/aot/tests/test_replace_ops_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2103,6 +2103,42 @@ def test_replace_linear_like_conv(self) -> None:
count_node(gm_after_replacement, exir_ops.edge.aten.view_copy.default), 1
)

def test_no_replace_for_channel_last(self) -> None:
# NHWC im2row rearranges data from kp-major (NHWC natural order) to
# channel-major output layout — it is not a no-op view. The pass must
# not elide it to view_copy even when shape conditions would otherwise
# allow replacement (kernel == input spatial dims).
in_h, in_w = 13, 15
x = torch.randn(1, in_h, in_w, 3) # NHWC
pad_value = torch.tensor(0, dtype=torch.int32)
channels_last = True
gm = single_op_builder(
placeholders=(x, pad_value),
op=exir_ops.edge.cadence.im2row.default,
args=(x, (in_h, in_w), (1, 1), (0, 0), (1, 1), pad_value, channels_last),
)
gm = ExportPass().call(gm).graph_module
self.assertEqual(count_node(gm, exir_ops.edge.cadence.im2row.default), 1)
self.assertEqual(count_node(gm, exir_ops.edge.aten.view_copy.default), 0)

gm_before = copy.deepcopy(gm)

p = ReplaceIm2RowWithViewPass()
result = p.call(gm)
self.assertFalse(result.modified)
gm_after_replacement = result.graph_module

inputs = [x, pad_value]
validate(gm_before, gm_after_replacement, inputs, "ReplaceIm2RowWithViewPass")

# No replacement: im2row remains, no new view_copy.
self.assertEqual(
count_node(gm_after_replacement, exir_ops.edge.cadence.im2row.default), 1
)
self.assertEqual(
count_node(gm_after_replacement, exir_ops.edge.aten.view_copy.default), 0
)


class TestReplaceConvWithChannelLastConvPass(unittest.TestCase):
def create_conv1d_graphmodule(
Expand Down
30 changes: 14 additions & 16 deletions backends/cadence/hifi/operators/op_im2row_out.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,34 +61,32 @@ __attribute__((always_inline)) void im2row_(
// array of size (out_height * out_width) x channels_col
const int32_t channels_col = channels * kernel_h * kernel_w;

// If the layout is NHWC, we can copy 'channels' worth of contiguous data
// points when performing im2row.
// If the layout is NHWC, the input data is contiguous per-pixel (H, W, C).
// The output layout must match torch.nn.functional.unfold, which is [c][kp]:
// output[c * num_kp + kp] for each output position.
if (channels_last) {
const int32_t num_kp = kernel_h * kernel_w;
// Iterate over the output domain
for (int _h = 0; _h < out_height; ++_h) {
for (int _w = 0; _w < out_width; ++_w) {
int32_t i_col = _h * out_width + _w;
// Each point in the output domain is the result of applying a filter of
// size kernel_h x kernel_w x channels on the input. But since channels
// is contiguous, we will not explicitly have a loop for it.
for (int _kh = 0; _kh < kernel_h; ++_kh) {
int32_t h_im = _h * stride_h - pad_h + _kh * dilation_h;
for (int _kw = 0; _kw < kernel_w; ++_kw) {
int32_t w_im = _w * stride_w - pad_w + _kw * dilation_w;
int32_t kp = _kh * kernel_w + _kw;

// h_im and w_im are the actual height and width coordinates of the
// input tensor from where we need to copy 'channels' points.
const T* __restrict__ slice_im =
data_im + (h_im * width + w_im) * channels;
T* __restrict__ slice_col = data_col + i_col * channels_col +
(_kh * kernel_w + _kw) * channels;
// If the coordinates were within the input domain, we copy
// 'channels' contiguous values. Otherwise we will fill the output
// with 0's.
if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
std::memcpy(slice_col, slice_im, channels * sizeof(T));
const T* __restrict__ pixel =
data_im + (h_im * width + w_im) * channels;
for (int _c = 0; _c < channels; ++_c) {
data_col[i_col * channels_col + _c * num_kp + kp] = pixel[_c];
}
} else {
std::fill_n(slice_col, channels, T(in_zero_point));
for (int _c = 0; _c < channels; ++_c) {
data_col[i_col * channels_col + _c * num_kp + kp] =
static_cast<T>(in_zero_point);
}
}
}
}
Expand Down
46 changes: 12 additions & 34 deletions backends/cadence/hifi/third-party/nnlib/xa_nn_im2row.c
Original file line number Diff line number Diff line change
Expand Up @@ -18,53 +18,31 @@ WORD32 xa_nn_im2row_quantized(
WORD8 *__restrict__ data_col, WORD32 channels_last) {
const WORD32 channels_col = channels * kernel_h * kernel_w;

// If the layout is NHWC, we can copy 'channels' worth of contiguous data
// points when performing im2row.
// If the layout is NHWC, the input data is contiguous per-pixel (H, W, C).
// The output layout must match torch.nn.functional.unfold, which is [c][kp]:
// output[c * num_kp + kp] for each output position.
if (channels_last) {
const int32_t num_kp = kernel_h * kernel_w;
// Iterate over the output domain
for (int _h = 0; _h < out_height; ++_h) {
for (int _w = 0; _w < out_width; ++_w) {
int32_t i_col = _h * out_width + _w;
// Each point in the output domain is the result of applying a filter of
// size kernel_h x kernel_w x channels on the input. But since channels
// is contiguous, we will not explicitly have a loop for it.
for (int _kh = 0; _kh < kernel_h; ++_kh) {
int32_t h_im = _h * stride_h - pad_h + _kh * dilation_h;
for (int _kw = 0; _kw < kernel_w; ++_kw) {
int32_t w_im = _w * stride_w - pad_w + _kw * dilation_w;
int32_t kp = _kh * kernel_w + _kw;

// h_im and w_im are the actual height and width coordinates of the
// input tensor from where we need to copy 'channels' points.
const int8_t *__restrict__ slice_im =
data_im + (h_im * width + w_im) * channels;
int8_t *__restrict__ slice_col = data_col + i_col * channels_col +
(_kh * kernel_w + _kw) * channels;
// If the coordinates were within the input domain, we copy
// 'channels' contiguous values. Otherwise we will fill the output
// with 0's.
if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
const ae_int32x2 *pae_inp = (const ae_int32x2 *)slice_im;
ae_int32x2 *pae_out = (ae_int32x2 *)slice_col;
ae_valign inp_a, out_a;
inp_a = AE_LA64_PP(pae_inp);
out_a = AE_ZALIGN64();

ae_int32x2 d0;
for (int ic = 0; ic < channels >> 3; ic++) {
AE_LA32X2_IP(d0, inp_a, pae_inp);
AE_SA32X2_IP(d0, out_a, pae_out);
}
AE_SA64POS_FP(out_a, pae_out);

int remainder = channels & 7;
int8_t *ptmp_in = (int8_t *)pae_inp;
int8_t *ptmp_out = (int8_t *)pae_out;
for (int ic = 0; ic < remainder; ic++) {
*ptmp_out++ = *ptmp_in++;
const int8_t *__restrict__ pixel =
data_im + (h_im * width + w_im) * channels;
for (int _c = 0; _c < channels; ++_c) {
data_col[i_col * channels_col + _c * num_kp + kp] = pixel[_c];
}
} else {
for (int i = 0; i < channels; i++) {
slice_col[i] = (int8_t)(in_zero_point);
for (int _c = 0; _c < channels; ++_c) {
data_col[i_col * channels_col + _c * num_kp + kp] =
(int8_t)(in_zero_point);
}
}
}
Expand Down
Loading