
[Arm Ethos-U] BatchNorm not delegated to the NPU #17397

@jonasdaugalas


🐛 Describe the bug

A simple batch normalization layer is not delegated to the Ethos-U55 NPU. The same model is fully mapped to the NPU via the tflite -> Vela flow.

reproduce/batchnorm.py (the aot_arm_compiler example loads the ModelUnderTest and ModelInputs definitions from this file):

import torch


ModelUnderTest = torch.nn.BatchNorm2d(num_features=256)
ModelUnderTest.weight.data.fill_(1.5)
ModelUnderTest.bias.data.fill_(0.5)
ModelInputs = (torch.randn(1, 256, 20, 20),)

(.venv) root@container:/opt/executorch# python -m examples.arm.aot_arm_compiler \
    --model_name reproduce/batchnorm.py \
    --target ethos-u55-128 \
    --system_config Ethos_U55_High_End_Embedded \
    --memory_mode Shared_Sram \
    --output reproduce/model.pte \
    --quantize --delegate \
    --intermediates reproduce/intermediates
W0211 21:18:16.321000 275320 torch/utils/flop_counter.py:45] triton not found; flop counting will not work for triton kernels
/usr/lib/python3.12/copyreg.py:99: FutureWarning: `isinstance(treespec, LeafSpec)` is deprecated, use `isinstance(treespec, TreeSpec) and treespec.is_leaf()` instead.
  return cls.__new__(cls, *args)
PTE file saved as reproduce/model.pte

(.venv) root@container:/opt/executorch# cat reproduce/intermediates/delegation_info.txt 
Delegation info:
Total delegated subgraphs: 0
Number of delegated nodes: 0
Number of non-delegated nodes: 2

Delegation table:
╒════╤═══════════════════════════════════════════════════╤═══════════════════════════════════╤═══════════════════════════════════════╕
│    │ op_type                                           │   occurrences_in_delegated_graphs │   occurrences_in_non_delegated_graphs │
╞════╪═══════════════════════════════════════════════════╪═══════════════════════════════════╪═══════════════════════════════════════╡
│  0 │ aten__native_batch_norm_legit_no_training_default │                                 0 │                                     1 │
├────┼───────────────────────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│  1 │ getitem                                           │                                 0 │                                     1 │
├────┼───────────────────────────────────────────────────┼───────────────────────────────────┼───────────────────────────────────────┤
│  2 │ Total                                             │                                 0 │                                     2 │
╘════╧═══════════════════════════════════════════════════╧═══════════════════════════════════╧═══════════════════════════════════════╛
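
For context, inference-time batch norm with frozen running stats is nothing more than a per-channel scale and shift, so the op itself should be straightforward to map onto the NPU. A quick sanity check in plain PyTorch (illustration only, no ExecuTorch APIs involved):

import torch

bn = torch.nn.BatchNorm2d(num_features=256).eval()
bn.weight.data.fill_(1.5)
bn.bias.data.fill_(0.5)

# In eval mode, y = (x - running_mean) / sqrt(running_var + eps) * weight + bias,
# i.e. a per-channel affine transform with constant scale and shift.
scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
shift = bn.bias - bn.running_mean * scale

x = torch.randn(1, 256, 20, 20)
folded = x * scale.view(1, -1, 1, 1) + shift.view(1, -1, 1, 1)
torch.testing.assert_close(bn(x), folded)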

The table above is the ACTUAL observed behavior.

The EXPECTED behavior is that the batch norm is delegated to the NPU, as it is via the tflite -> Vela path:

import subprocess
from pathlib import Path
import torch
import numpy as np
import litert_torch
import tensorflow as tf

OUTPUT_DIR = Path(f"out_{Path(__file__).stem}")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)


input_shape = (256, 20, 20)
batch_size = 1
inputs = (torch.randn(batch_size, *input_shape),)

torch_model = torch.nn.BatchNorm2d(num_features=256)
torch_model.weight.data.fill_(1.5)
torch_model.bias.data.fill_(0.5)


def representative_dataset(num_samples: int = 10):
    rng = np.random.default_rng(seed=0)
    for _ in range(num_samples):
        sample = rng.normal(loc=0.0, scale=1.0, size=input_shape).astype(np.float32)
        yield [sample]


tfl_converter_flags = {
    "optimizations": [tf.lite.Optimize.DEFAULT],
    "representative_dataset": representative_dataset,
    "target_spec": tf.lite.TargetSpec(supported_ops=[tf.lite.OpsSet.TFLITE_BUILTINS_INT8]),
    "inference_input_type": tf.int8,
    "inference_output_type": tf.int8,
}
edge_model = litert_torch.convert(torch_model.eval(), inputs, _ai_edge_converter_flags=tfl_converter_flags)
int8_path = OUTPUT_DIR / "model_int8.tflite"
edge_model.export(int8_path)


cmd = [
    "uvx",
    "--from",
    "ethos-u-vela",
    "vela",
    str(int8_path),
    "--output-dir",
    OUTPUT_DIR.as_posix(),
    "--accelerator-config",
    "ethos-u55-128",
    "--system-config",
    "Ethos_U55_High_End_Embedded",
    "--memory-mode",
    "Shared_Sram",
    "--config",
    "Arm/vela.ini",
]

result = subprocess.run(cmd, check=True, capture_output=True, text=True)
log_path = OUTPUT_DIR / "vela.log"
log_path.write_text(result.stdout)
print(result.stdout)

The Vela network summary in this script's output shows that both operators were mapped to the NPU:


Network summary for model_int8
Accelerator configuration               Ethos_U55_128
System configuration             Ethos_U55_High_End_Embedded
Memory mode                               Shared_Sram
Accelerator clock                                 500 MHz
Design peak SRAM bandwidth                       3.73 GB/s
Design peak Off-chip Flash bandwidth             0.47 GB/s

Total SRAM used                                100.00 KiB
Total Off-chip Flash used                        0.91 KiB

CPU operators = 0 (0.0%)
NPU operators = 2 (100.0%)

Average SRAM bandwidth                           0.12 GB/s
Input   SRAM bandwidth                           0.20 MB/batch
Weight  SRAM bandwidth                           0.00 MB/batch
Output  SRAM bandwidth                           0.20 MB/batch
Total   SRAM bandwidth                           0.39 MB/batch
Total   SRAM bandwidth            per input      0.39 MB/inference (batch size 1)

Average Off-chip Flash bandwidth                 0.06 GB/s
Input   Off-chip Flash bandwidth                 0.20 MB/batch
Weight  Off-chip Flash bandwidth                 0.00 MB/batch
Output  Off-chip Flash bandwidth                 0.00 MB/batch
Total   Off-chip Flash bandwidth                 0.20 MB/batch
Total   Off-chip Flash bandwidth  per input      0.20 MB/inference (batch size 1)

Neural network macs                                 0 MACs/batch
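
A guess at the root cause, offered as speculation only: the Arm backend might handle batch norm solely by folding it into a preceding convolution, in which case a standalone BatchNorm2d has nothing to fuse with. A variant model file for testing that hypothesis (hypothetical, not run here):

import torch

# Hypothetical variant: the same batch norm, but preceded by a conv it could be
# folded into. If this delegates while the standalone batch norm does not, that
# would point at a missing lowering for standalone batch norm.
ModelUnderTest = torch.nn.Sequential(
    torch.nn.Conv2d(256, 256, kernel_size=3, padding=1),
    torch.nn.BatchNorm2d(num_features=256),
)
ModelInputs = (torch.randn(1, 256, 20, 20),)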

Versions

Collecting environment information...
PyTorch version: 2.11.0.dev20251222+cpu
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: Ubuntu 24.04.3 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
Clang version: 18.1.3 (1ubuntu1)
CMake version: version 3.31.10
Libc version: glibc-2.39

Python version: 3.12.3 (main, Jan 8 2026, 11:30:50) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-6.6.87.2-microsoft-standard-WSL2-x86_64-with-glibc2.39
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 39 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 8
On-line CPU(s) list: 0-7
Vendor ID: GenuineIntel
Model name: 11th Gen Intel(R) Core(TM) i7-1185G7 @ 3.00GHz
CPU family: 6
Model: 140
Thread(s) per core: 2
Core(s) per socket: 4
Socket(s): 1
Stepping: 1
BogoMIPS: 5990.41
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology tsc_reliable nonstop_tsc cpuid tsc_known_freq pni pclmulqdq vmx ssse3 fma cx16 pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves vnmi avx512vbmi umip avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq rdpid movdiri movdir64b fsrm avx512_vp2intersect md_clear flush_l1d arch_capabilities
Virtualization: VT-x
Hypervisor vendor: Microsoft
Virtualization type: full
L1d cache: 192 KiB (4 instances)
L1i cache: 128 KiB (4 instances)
L2 cache: 5 MiB (4 instances)
L3 cache: 12 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-7
Vulnerability Gather data sampling: Unknown: Dependent on hypervisor status
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Mitigation; Enhanced IBRS
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI SW loop, KVM SW loop
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] executorch==1.2.0a0+6e1e6c9
[pip3] numpy==2.4.2
[pip3] optree==0.18.0
[pip3] pytorch_tokenizers==1.0.1
[pip3] torch==2.11.0.dev20251222+cpu
[pip3] torchao==0.16.0+git28306f085
[pip3] torchaudio==2.10.0.dev20251222+cpu
[pip3] torchdata==0.11.0
[pip3] torchsr==1.0.4
[pip3] torchtune==0.0.0
[pip3] torchvision==0.25.0.dev20251222+cpu
[conda] Could not collect

cc @freddan80 @per @zingo @oscarandersson8218 @digantdesai
