Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 39 additions & 1 deletion backends/arm/operator_support/to_dim_order_copy_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,26 @@ def _merge_supported_types(
torch.float32,
],
}
# Dtype pairs merged into the supported set when the "fp8e4m3" extension is
# available: float16/float32 <-> float8_e4m3fn.
# NOTE(review): presumably each key is the input dtype and the list holds the
# allowed output dtypes -- confirm against the lookup code.
SUPPORTED_FP8E4M3_EXTENSION_DTYPES: SupportedTypeDict = {
torch.float16: [torch.float8_e4m3fn],
torch.float32: [torch.float8_e4m3fn],
torch.float8_e4m3fn: [torch.float16, torch.float32],
}
# Dtype pairs merged into the supported set when the "fp8e5m2" extension is
# available: float16/float32 <-> float8_e5m2.
SUPPORTED_FP8E5M2_EXTENSION_DTYPES: SupportedTypeDict = {
torch.float16: [torch.float8_e5m2],
torch.float32: [torch.float8_e5m2],
torch.float8_e5m2: [torch.float16, torch.float32],
}
# Dtype pairs merged only when both the "bf16" and "fp8e4m3" extensions are
# available: bfloat16 <-> float8_e4m3fn.
SUPPORTED_BF16_FP8E4M3_EXTENSION_DTYPES: SupportedTypeDict = {
torch.bfloat16: [torch.float8_e4m3fn],
torch.float8_e4m3fn: [torch.bfloat16],
}
# Dtype pairs merged only when both the "bf16" and "fp8e5m2" extensions are
# available: bfloat16 <-> float8_e5m2.
SUPPORTED_BF16_FP8E5M2_EXTENSION_DTYPES: SupportedTypeDict = {
torch.bfloat16: [torch.float8_e5m2],
torch.float8_e5m2: [torch.bfloat16],
}

def is_node_tosa_supported(
def is_node_tosa_supported( # noqa: C901
self, node: fx.Node, tosa_spec: TosaSpecification
) -> bool:
"""Return True if the node is supported by TOSA.
Expand All @@ -148,6 +166,26 @@ def is_node_tosa_supported(
supported_dtypes = self._merge_supported_types(
self.SUPPORTED_BF16_EXTENSION_DTYPES, supported_dtypes
)
if tosa_spec.support_extension("fp8e4m3"):
supported_dtypes = self._merge_supported_types(
self.SUPPORTED_FP8E4M3_EXTENSION_DTYPES, supported_dtypes
)
if tosa_spec.support_extension("fp8e5m2"):
supported_dtypes = self._merge_supported_types(
self.SUPPORTED_FP8E5M2_EXTENSION_DTYPES, supported_dtypes
)
if tosa_spec.support_extension("bf16") and tosa_spec.support_extension(
"fp8e4m3"
):
supported_dtypes = self._merge_supported_types(
self.SUPPORTED_BF16_FP8E4M3_EXTENSION_DTYPES, supported_dtypes
)
if tosa_spec.support_extension("bf16") and tosa_spec.support_extension(
"fp8e5m2"
):
supported_dtypes = self._merge_supported_types(
self.SUPPORTED_BF16_FP8E5M2_EXTENSION_DTYPES, supported_dtypes
)

if len(node.all_input_nodes) != 1:
self.reporter.report_reject(
Expand Down
4 changes: 4 additions & 0 deletions backends/arm/operator_support/tosa_supported_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,10 @@ def tosa_support_factory(
disallowed_dtypes = [torch.float64]
if not tosa_spec.support_extension("bf16"):
disallowed_dtypes.append(torch.bfloat16)
if not tosa_spec.support_extension("fp8e4m3"):
disallowed_dtypes.append(torch.float8_e4m3fn)
if not tosa_spec.support_extension("fp8e5m2"):
disallowed_dtypes.append(torch.float8_e5m2)
if tosa_spec.is_U55_subset:
disallowed_dtypes.append(torch.bool)
negative_checks.append(
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/operators/op_tosa_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ def define_node(
ts.DType.FP16,
ts.DType.FP32,
ts.DType.BF16,
ts.DType.FP8E4M3,
ts.DType.FP8E5M2,
],
self.tosa_spec,
)
Expand Down
14 changes: 11 additions & 3 deletions backends/arm/process_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,22 @@

def _tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray:
tensor = tensor.detach().cpu().contiguous()
if tensor.dtype == torch.bfloat16:
if tensor.dtype in (torch.bfloat16, torch.float8_e4m3fn, torch.float8_e5m2):
try:
import ml_dtypes # type: ignore[import-not-found]
except ImportError as e:
raise RuntimeError(
"ml_dtypes is required to serialize bfloat16 tensors for TOSA. Have you run setup.sh?"
f"ml_dtypes is required to serialize {tensor.dtype} tensors for TOSA. "
"Have you run setup.sh?"
) from e
return tensor.view(torch.uint16).numpy().view(ml_dtypes.bfloat16)

ml_dtype_map = {
torch.bfloat16: (torch.uint16, ml_dtypes.bfloat16),
torch.float8_e4m3fn: (torch.uint8, ml_dtypes.float8_e4m3fn),
torch.float8_e5m2: (torch.uint8, ml_dtypes.float8_e5m2),
}
storage_dtype, ml_dtype = ml_dtype_map[tensor.dtype]
return tensor.view(storage_dtype).numpy().view(ml_dtype)
else:
return tensor.numpy()

Expand Down
124 changes: 124 additions & 0 deletions backends/arm/test/ops/test_to_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,130 @@ def test_to_tosa_FP_bf16_with_extension():
pipeline.run()


# Cast test cases between float16/float32 and the FP8 dtypes.
# Each value is a zero-argument factory returning
# (input tensor, target dtype, TOSA extension name); the factory is only
# invoked inside the test. FP8 inputs are produced by casting random float32
# data to the FP8 dtype.
_TO_COPY_TEST_DATA_FP_FP8 = {
"fp32_to_fp8e4m3": lambda: (
torch.rand((1, 2, 3, 4), dtype=torch.float32),
torch.float8_e4m3fn,
"fp8e4m3",
),
"fp16_to_fp8e5m2": lambda: (
torch.rand((1, 2, 3, 4), dtype=torch.float16),
torch.float8_e5m2,
"fp8e5m2",
),
"fp8e4m3_to_fp32": lambda: (
torch.rand((1, 2, 3, 4), dtype=torch.float32).to(torch.float8_e4m3fn),
torch.float32,
"fp8e4m3",
),
"fp8e5m2_to_fp16": lambda: (
torch.rand((1, 2, 3, 4), dtype=torch.float32).to(torch.float8_e5m2),
torch.float16,
"fp8e5m2",
),
}


def test_to_tosa_FP_fp8e4m3_requires_extension():
    """Run a float32 -> float8_e4m3fn cast through OpNotSupportedPipeline.

    No TOSA extensions are passed to the pipeline.
    """
    fp32_input = torch.rand((1, 2, 3, 4), dtype=torch.float32)
    # NOTE(review): presumably op name -> expected count of ops left
    # undelegated; confirm against OpNotSupportedPipeline.
    expected_ops = {
        "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1
    }
    OpNotSupportedPipeline[input_t1](
        Cast(torch.float8_e4m3fn),
        (fp32_input,),
        expected_ops,
    ).run()


def test_to_tosa_FP_fp8e5m2_requires_extension():
    """Run a float16 -> float8_e5m2 cast through OpNotSupportedPipeline.

    No TOSA extensions are passed to the pipeline.
    """
    fp16_input = torch.rand((1, 2, 3, 4), dtype=torch.float16)
    # NOTE(review): presumably op name -> expected count of ops left
    # undelegated; confirm against OpNotSupportedPipeline.
    expected_ops = {
        "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1
    }
    OpNotSupportedPipeline[input_t1](
        Cast(torch.float8_e5m2),
        (fp16_input,),
        expected_ops,
    ).run()


def test_to_tosa_FP_bf16_to_fp8e4m3_requires_both_extensions():
    """Run a bfloat16 -> float8_e4m3fn cast through OpNotSupportedPipeline.

    Only the "bf16" extension is enabled; "fp8e4m3" is deliberately absent.
    """
    bf16_input = torch.rand((1, 2, 3, 4), dtype=torch.bfloat16)
    # NOTE(review): presumably op name -> expected count of ops left
    # undelegated; confirm against OpNotSupportedPipeline.
    expected_ops = {
        "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1
    }
    OpNotSupportedPipeline[input_t1](
        Cast(torch.float8_e4m3fn),
        (bf16_input,),
        expected_ops,
        tosa_extensions=["bf16"],
    ).run()


def test_to_tosa_FP_bf16_to_fp8e5m2_requires_both_extensions():
    """Run a bfloat16 -> float8_e5m2 cast through OpNotSupportedPipeline.

    Only the "bf16" extension is enabled; "fp8e5m2" is deliberately absent.
    """
    bf16_input = torch.rand((1, 2, 3, 4), dtype=torch.bfloat16)
    # NOTE(review): presumably op name -> expected count of ops left
    # undelegated; confirm against OpNotSupportedPipeline.
    expected_ops = {
        "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1
    }
    OpNotSupportedPipeline[input_t1](
        Cast(torch.float8_e5m2),
        (bf16_input,),
        expected_ops,
        tosa_extensions=["bf16"],
    ).run()


@common.parametrize("test_data", _TO_COPY_TEST_DATA_FP_FP8)
def test_to_tosa_FP_fp8_with_extension(test_data: Tuple):
    """Run each FP <-> FP8 cast case through TosaPipelineFP.

    The single TOSA extension named by the test case is enabled.
    """
    source_tensor, target_dtype, extension_name = test_data()
    cast_module = Cast(target_dtype)
    fp_pipeline = TosaPipelineFP[input_t1](
        cast_module,
        (source_tensor,),
        aten_op=[],
        exir_op=[],
        tosa_extensions=[extension_name],
    )
    fp_pipeline.run()


# Cast test cases between bfloat16 and the FP8 dtypes.
# Each value is a zero-argument factory returning
# (input tensor, target dtype, list of TOSA extension names); every case
# enables "bf16" together with the matching FP8 extension. FP8 inputs are
# produced by casting random float32 data to the FP8 dtype.
_TO_COPY_TEST_DATA_BF16_FP8 = {
"bf16_to_fp8e4m3": lambda: (
torch.rand((1, 2, 3, 4), dtype=torch.bfloat16),
torch.float8_e4m3fn,
["bf16", "fp8e4m3"],
),
"fp8e4m3_to_bf16": lambda: (
torch.rand((1, 2, 3, 4), dtype=torch.float32).to(torch.float8_e4m3fn),
torch.bfloat16,
["bf16", "fp8e4m3"],
),
"bf16_to_fp8e5m2": lambda: (
torch.rand((1, 2, 3, 4), dtype=torch.bfloat16),
torch.float8_e5m2,
["bf16", "fp8e5m2"],
),
"fp8e5m2_to_bf16": lambda: (
torch.rand((1, 2, 3, 4), dtype=torch.float32).to(torch.float8_e5m2),
torch.bfloat16,
["bf16", "fp8e5m2"],
),
}


@common.parametrize("test_data", _TO_COPY_TEST_DATA_BF16_FP8)
def test_to_tosa_FP_bf16_fp8_with_extensions(test_data: Tuple):
    """Run each bfloat16 <-> FP8 cast case through TosaPipelineFP.

    The list of TOSA extensions supplied by the test case is enabled.
    """
    source_tensor, target_dtype, enabled_extensions = test_data()
    cast_module = Cast(target_dtype)
    fp_pipeline = TosaPipelineFP[input_t1](
        cast_module,
        (source_tensor,),
        aten_op=[],
        exir_op=[],
        tosa_extensions=enabled_extensions,
    )
    fp_pipeline.run()


@common.parametrize("test_data", _TO_COPY_TEST_DATA_FP)
@common.SkipIfNoModelConverter
def test_to_vgf_no_quant(test_data: Tuple):
Expand Down
6 changes: 6 additions & 0 deletions backends/arm/test/runner_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@
torch.int16: np.int16,
torch.int32: np.int32,
torch.int64: np.int64,
torch.float8_e4m3fn: np.uint8,
torch.float8_e5m2: np.uint8,
torch.float16: np.float16,
torch.float32: np.float32,
torch.float64: np.float64,
Expand Down Expand Up @@ -190,6 +192,10 @@ def torch_tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray:
if tensor.dtype == torch.bfloat16:
# Numpy doesn't support bfloat16, so use uint16 instead. The dtype is inferred from the model anyway.
tensor = tensor.view(torch.uint16)
elif tensor.dtype == torch.float8_e4m3fn:
tensor = tensor.view(torch.uint8)
elif tensor.dtype == torch.float8_e5m2:
tensor = tensor.view(torch.uint8)
return tensor.numpy()


Expand Down
8 changes: 8 additions & 0 deletions backends/arm/tosa/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ def map_dtype(data_type: torch.dtype) -> Any:
torch.int32: ts.DType.INT32,
torch.int: ts.DType.INT32,
torch.bool: ts.DType.BOOL,
torch.float8_e4m3fn: ts.DType.FP8E4M3,
torch.float8_e5m2: ts.DType.FP8E5M2,
}
if data_type not in dtype_map:
raise ValueError(f"Unknown type: {data_type}")
Expand Down Expand Up @@ -231,6 +233,12 @@ def __validate(self, tosa_spec: TosaSpecification) -> bool:
case ts.DType.BF16:
if not tosa_spec.support_extension("bf16"):
return False
case ts.DType.FP8E4M3:
if not tosa_spec.support_extension("fp8e4m3"):
return False
case ts.DType.FP8E5M2:
if not tosa_spec.support_extension("fp8e5m2"):
return False

return True

Expand Down
15 changes: 11 additions & 4 deletions backends/test/harness/tester.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2025 Arm Limited and/or its affiliates.
# Copyright 2025-2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand Down Expand Up @@ -439,17 +439,24 @@ def _assert_outputs_equal(
f"\tMismatched count: {(model != ref).sum().item()} / {model.numel()}\n"
)
else:
# torch.allclose() does not have a CPU implementation for FP8 tensors
# in some PyTorch builds, so compare FP8 outputs in float32 instead.
compare_model = model
compare_ref = ref
if model.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
compare_model = model.to(torch.float32)
compare_ref = ref.to(torch.float32)
assert torch.allclose(
model,
ref,
compare_model,
compare_ref,
atol=atol,
rtol=rtol,
equal_nan=True,
), (
f"Output {i} does not match reference output.\n"
f"\tGiven atol: {atol}, rtol: {rtol}.\n"
f"\tOutput tensor shape: {model.shape}, dtype: {model.dtype}\n"
f"\tDifference: max: {torch.max(model-ref)}, abs: {torch.max(torch.abs(model-ref))}, mean abs error: {torch.mean(torch.abs(model-ref).to(torch.double))}.\n"
f"\tDifference: max: {torch.max(compare_model-compare_ref)}, abs: {torch.max(torch.abs(compare_model-compare_ref))}, mean abs error: {torch.mean(torch.abs(compare_model-compare_ref).to(torch.double))}.\n"
f"\t-- Model vs. Reference --\n"
f"\t Numel: {model.numel()}, {ref.numel()}\n"
f"\tMedian: {model.median()}, {ref.median()}\n"
Expand Down
Loading