diff --git a/backends/arm/operator_support/to_dim_order_copy_support.py b/backends/arm/operator_support/to_dim_order_copy_support.py index a1560ba8cd8..a02a8e16276 100644 --- a/backends/arm/operator_support/to_dim_order_copy_support.py +++ b/backends/arm/operator_support/to_dim_order_copy_support.py @@ -120,8 +120,26 @@ def _merge_supported_types( torch.float32, ], } + SUPPORTED_FP8E4M3_EXTENSION_DTYPES: SupportedTypeDict = { + torch.float16: [torch.float8_e4m3fn], + torch.float32: [torch.float8_e4m3fn], + torch.float8_e4m3fn: [torch.float16, torch.float32], + } + SUPPORTED_FP8E5M2_EXTENSION_DTYPES: SupportedTypeDict = { + torch.float16: [torch.float8_e5m2], + torch.float32: [torch.float8_e5m2], + torch.float8_e5m2: [torch.float16, torch.float32], + } + SUPPORTED_BF16_FP8E4M3_EXTENSION_DTYPES: SupportedTypeDict = { + torch.bfloat16: [torch.float8_e4m3fn], + torch.float8_e4m3fn: [torch.bfloat16], + } + SUPPORTED_BF16_FP8E5M2_EXTENSION_DTYPES: SupportedTypeDict = { + torch.bfloat16: [torch.float8_e5m2], + torch.float8_e5m2: [torch.bfloat16], + } - def is_node_tosa_supported( + def is_node_tosa_supported( # noqa: C901 self, node: fx.Node, tosa_spec: TosaSpecification ) -> bool: """Return True if the node is supported by TOSA. @@ -148,6 +166,26 @@ def is_node_tosa_supported( supported_dtypes = self._merge_supported_types( self.SUPPORTED_BF16_EXTENSION_DTYPES, supported_dtypes ) + if tosa_spec.support_extension("fp8e4m3"): + supported_dtypes = self._merge_supported_types( + self.SUPPORTED_FP8E4M3_EXTENSION_DTYPES, supported_dtypes + ) + if tosa_spec.support_extension("fp8e5m2"): + supported_dtypes = self._merge_supported_types( + self.SUPPORTED_FP8E5M2_EXTENSION_DTYPES, supported_dtypes + ) + if tosa_spec.support_extension("bf16") and tosa_spec.support_extension( + "fp8e4m3" + ): + supported_dtypes = self._merge_supported_types( + self.SUPPORTED_BF16_FP8E4M3_EXTENSION_DTYPES, supported_dtypes + ) + if tosa_spec.support_extension("bf16") and tosa_spec.support_extension( + "fp8e5m2" + ): + supported_dtypes = self._merge_supported_types( + self.SUPPORTED_BF16_FP8E5M2_EXTENSION_DTYPES, supported_dtypes + ) if len(node.all_input_nodes) != 1: self.reporter.report_reject( diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index 728012ad457..0819e137827 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -295,6 +295,10 @@ def tosa_support_factory( disallowed_dtypes = [torch.float64] if not tosa_spec.support_extension("bf16"): disallowed_dtypes.append(torch.bfloat16) + if not tosa_spec.support_extension("fp8e4m3"): + disallowed_dtypes.append(torch.float8_e4m3fn) + if not tosa_spec.support_extension("fp8e5m2"): + disallowed_dtypes.append(torch.float8_e5m2) if tosa_spec.is_U55_subset: disallowed_dtypes.append(torch.bool) negative_checks.append( diff --git a/backends/arm/operators/op_tosa_identity.py b/backends/arm/operators/op_tosa_identity.py index 53d75c8f24c..1b15e39154e 100644 --- a/backends/arm/operators/op_tosa_identity.py +++ b/backends/arm/operators/op_tosa_identity.py @@ -46,6 +46,8 @@ def define_node( ts.DType.FP16, ts.DType.FP32, ts.DType.BF16, + ts.DType.FP8E4M3, + ts.DType.FP8E5M2, ], self.tosa_spec, ) diff --git a/backends/arm/process_node.py b/backends/arm/process_node.py index e2d76927500..f86df9627ff 100644 --- a/backends/arm/process_node.py +++ b/backends/arm/process_node.py @@ -30,14 +30,22 @@ def _tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray: tensor = tensor.detach().cpu().contiguous() - if tensor.dtype == torch.bfloat16: + if tensor.dtype in (torch.bfloat16, torch.float8_e4m3fn, torch.float8_e5m2): try: import ml_dtypes # type: ignore[import-not-found] except ImportError as e: raise RuntimeError( - "ml_dtypes is required to serialize bfloat16 tensors for TOSA. Have you run setup.sh?" + f"ml_dtypes is required to serialize {tensor.dtype} tensors for TOSA. " + "Have you run setup.sh?" ) from e - return tensor.view(torch.uint16).numpy().view(ml_dtypes.bfloat16) + + ml_dtype_map = { + torch.bfloat16: (torch.uint16, ml_dtypes.bfloat16), + torch.float8_e4m3fn: (torch.uint8, ml_dtypes.float8_e4m3fn), + torch.float8_e5m2: (torch.uint8, ml_dtypes.float8_e5m2), + } + storage_dtype, ml_dtype = ml_dtype_map[tensor.dtype] + return tensor.view(storage_dtype).numpy().view(ml_dtype) else: return tensor.numpy() diff --git a/backends/arm/test/ops/test_to_copy.py b/backends/arm/test/ops/test_to_copy.py index 00e141a37a2..6718fedea04 100644 --- a/backends/arm/test/ops/test_to_copy.py +++ b/backends/arm/test/ops/test_to_copy.py @@ -116,6 +116,130 @@ def test_to_tosa_FP_bf16_with_extension(): pipeline.run() +_TO_COPY_TEST_DATA_FP_FP8 = { + "fp32_to_fp8e4m3": lambda: ( + torch.rand((1, 2, 3, 4), dtype=torch.float32), + torch.float8_e4m3fn, + "fp8e4m3", + ), + "fp16_to_fp8e5m2": lambda: ( + torch.rand((1, 2, 3, 4), dtype=torch.float16), + torch.float8_e5m2, + "fp8e5m2", + ), + "fp8e4m3_to_fp32": lambda: ( + torch.rand((1, 2, 3, 4), dtype=torch.float32).to(torch.float8_e4m3fn), + torch.float32, + "fp8e4m3", + ), + "fp8e5m2_to_fp16": lambda: ( + torch.rand((1, 2, 3, 4), dtype=torch.float32).to(torch.float8_e5m2), + torch.float16, + "fp8e5m2", + ), +} + + +def test_to_tosa_FP_fp8e4m3_requires_extension(): + test_tensor = torch.rand((1, 2, 3, 4), dtype=torch.float32) + pipeline = OpNotSupportedPipeline[input_t1]( + Cast(torch.float8_e4m3fn), + (test_tensor,), + { + "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1 + }, + ) + pipeline.run() + + +def test_to_tosa_FP_fp8e5m2_requires_extension(): + test_tensor = torch.rand((1, 2, 3, 4), dtype=torch.float16) + pipeline = OpNotSupportedPipeline[input_t1]( + Cast(torch.float8_e5m2), + (test_tensor,), + { + "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1 + }, + ) + pipeline.run() + + +def test_to_tosa_FP_bf16_to_fp8e4m3_requires_both_extensions(): + test_tensor = torch.rand((1, 2, 3, 4), dtype=torch.bfloat16) + pipeline = OpNotSupportedPipeline[input_t1]( + Cast(torch.float8_e4m3fn), + (test_tensor,), + { + "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1 + }, + tosa_extensions=["bf16"], + ) + pipeline.run() + + +def test_to_tosa_FP_bf16_to_fp8e5m2_requires_both_extensions(): + test_tensor = torch.rand((1, 2, 3, 4), dtype=torch.bfloat16) + pipeline = OpNotSupportedPipeline[input_t1]( + Cast(torch.float8_e5m2), + (test_tensor,), + { + "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 1 + }, + tosa_extensions=["bf16"], + ) + pipeline.run() + + +@common.parametrize("test_data", _TO_COPY_TEST_DATA_FP_FP8) +def test_to_tosa_FP_fp8_with_extension(test_data: Tuple): + test_tensor, new_dtype, tosa_extension = test_data() + pipeline = TosaPipelineFP[input_t1]( + Cast(new_dtype), + (test_tensor,), + aten_op=[], + exir_op=[], + tosa_extensions=[tosa_extension], + ) + pipeline.run() + + +_TO_COPY_TEST_DATA_BF16_FP8 = { + "bf16_to_fp8e4m3": lambda: ( + torch.rand((1, 2, 3, 4), dtype=torch.bfloat16), + torch.float8_e4m3fn, + ["bf16", "fp8e4m3"], + ), + "fp8e4m3_to_bf16": lambda: ( + torch.rand((1, 2, 3, 4), dtype=torch.float32).to(torch.float8_e4m3fn), + torch.bfloat16, + ["bf16", "fp8e4m3"], + ), + "bf16_to_fp8e5m2": lambda: ( + torch.rand((1, 2, 3, 4), dtype=torch.bfloat16), + torch.float8_e5m2, + ["bf16", "fp8e5m2"], + ), + "fp8e5m2_to_bf16": lambda: ( + torch.rand((1, 2, 3, 4), dtype=torch.float32).to(torch.float8_e5m2), + torch.bfloat16, + ["bf16", "fp8e5m2"], + ), +} + + +@common.parametrize("test_data", _TO_COPY_TEST_DATA_BF16_FP8) +def test_to_tosa_FP_bf16_fp8_with_extensions(test_data: Tuple): + test_tensor, new_dtype, tosa_extensions = test_data() + pipeline = TosaPipelineFP[input_t1]( + Cast(new_dtype), + (test_tensor,), + aten_op=[], + exir_op=[], + tosa_extensions=tosa_extensions, + ) + pipeline.run() + + @common.parametrize("test_data", _TO_COPY_TEST_DATA_FP) @common.SkipIfNoModelConverter def test_to_vgf_no_quant(test_data: Tuple): diff --git a/backends/arm/test/runner_utils.py b/backends/arm/test/runner_utils.py index 93887fbda6b..5e62c4506f9 100644 --- a/backends/arm/test/runner_utils.py +++ b/backends/arm/test/runner_utils.py @@ -62,6 +62,8 @@ torch.int16: np.int16, torch.int32: np.int32, torch.int64: np.int64, + torch.float8_e4m3fn: np.uint8, + torch.float8_e5m2: np.uint8, torch.float16: np.float16, torch.float32: np.float32, torch.float64: np.float64, @@ -190,6 +192,10 @@ def torch_tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray: if tensor.dtype == torch.bfloat16: # Numpy doesn't support bfloat16, use, uint16 instead. Dtype is inferred from model anyways. tensor = tensor.view(torch.uint16) + elif tensor.dtype == torch.float8_e4m3fn: + tensor = tensor.view(torch.uint8) + elif tensor.dtype == torch.float8_e5m2: + tensor = tensor.view(torch.uint8) return tensor.numpy() diff --git a/backends/arm/tosa/mapping.py b/backends/arm/tosa/mapping.py index 9741ca167c0..b37c41a070b 100644 --- a/backends/arm/tosa/mapping.py +++ b/backends/arm/tosa/mapping.py @@ -106,6 +106,8 @@ def map_dtype(data_type: torch.dtype) -> Any: torch.int32: ts.DType.INT32, torch.int: ts.DType.INT32, torch.bool: ts.DType.BOOL, + torch.float8_e4m3fn: ts.DType.FP8E4M3, + torch.float8_e5m2: ts.DType.FP8E5M2, } if data_type not in dtype_map: raise ValueError(f"Unknown type: {data_type}") @@ -231,6 +233,12 @@ def __validate(self, tosa_spec: TosaSpecification) -> bool: case ts.DType.BF16: if not tosa_spec.support_extension("bf16"): return False + case ts.DType.FP8E4M3: + if not tosa_spec.support_extension("fp8e4m3"): + return False + case ts.DType.FP8E5M2: + if not tosa_spec.support_extension("fp8e5m2"): + return False return True diff --git a/backends/test/harness/tester.py b/backends/test/harness/tester.py index ea5fd21cb99..d237d15d717 100644 --- a/backends/test/harness/tester.py +++ b/backends/test/harness/tester.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -439,9 +439,16 @@ def _assert_outputs_equal( f"\tMismatched count: {(model != ref).sum().item()} / {model.numel()}\n" ) else: + # torch.allclose() does not have a CPU implementation for FP8 tensors + # in some PyTorch builds, so compare FP8 outputs in float32 instead. + compare_model = model + compare_ref = ref + if model.dtype in (torch.float8_e4m3fn, torch.float8_e5m2): + compare_model = model.to(torch.float32) + compare_ref = ref.to(torch.float32) assert torch.allclose( - model, - ref, + compare_model, + compare_ref, atol=atol, rtol=rtol, equal_nan=True, @@ -449,7 +456,7 @@ def _assert_outputs_equal( f"Output {i} does not match reference output.\n" f"\tGiven atol: {atol}, rtol: {rtol}.\n" f"\tOutput tensor shape: {model.shape}, dtype: {model.dtype}\n" - f"\tDifference: max: {torch.max(model-ref)}, abs: {torch.max(torch.abs(model-ref))}, mean abs error: {torch.mean(torch.abs(model-ref).to(torch.double))}.\n" + f"\tDifference: max: {torch.max(compare_model-compare_ref)}, abs: {torch.max(torch.abs(compare_model-compare_ref))}, mean abs error: {torch.mean(torch.abs(compare_model-compare_ref).to(torch.double))}.\n" f"\t-- Model vs. Reference --\n" f"\t Numel: {model.numel()}, {ref.numel()}\n" f"\tMedian: {model.median()}, {ref.median()}\n"