From 145e286aa709ebb4ed16910bc64d8e0dd502d9c5 Mon Sep 17 00:00:00 2001
From: George Tzoupis
Date: Thu, 12 Feb 2026 08:01:05 -0800
Subject: [PATCH] Add complex type support to div operator (#17414)

Summary: Add complex dtype support (ComplexFloat and ComplexDouble) to the
div.out operator. The optimized and portable CPU kernels gain an elementwise
complex-division path, and new kernel tests cover both dtypes.

Differential Revision: D93086411
---
 kernels/optimized/cpu/op_div.cpp | 23 ++++++--
 kernels/portable/cpu/op_div.cpp  | 47 ++++++++++------
 kernels/test/op_div_test.cpp     | 95 ++++++++++++++++++++++++++++++++
 3 files changed, 144 insertions(+), 21 deletions(-)

diff --git a/kernels/optimized/cpu/op_div.cpp b/kernels/optimized/cpu/op_div.cpp
index d74a293af8a..c2da64e7088 100644
--- a/kernels/optimized/cpu/op_div.cpp
+++ b/kernels/optimized/cpu/op_div.cpp
@@ -21,10 +21,11 @@ namespace native {
 namespace {
 
 ScalarType get_common_type(ScalarType a_type, ScalarType b_type) {
-  ET_CHECK(
-      !isComplexType(a_type) && !isQIntType(a_type) && !isBitsType(a_type));
-  ET_CHECK(
-      !isComplexType(b_type) && !isQIntType(b_type) && !isBitsType(b_type));
+  if (isComplexType(a_type) || isComplexType(b_type)) {
+    return promoteTypes(a_type, b_type);
+  }
+  ET_CHECK(!isQIntType(a_type) && !isBitsType(a_type));
+  ET_CHECK(!isQIntType(b_type) && !isBitsType(b_type));
 
   if (isFloatingType(a_type) && isFloatingType(b_type)) {
     return promoteTypes(a_type, b_type);
@@ -61,6 +62,20 @@ Tensor& opt_div_out(
   ScalarType b_type = b.scalar_type();
   ScalarType out_type = out.scalar_type();
 
+  // Handle complex types
+  if (isComplexType(a_type) || isComplexType(b_type)) {
+    ScalarType common_type = get_common_type(a_type, b_type);
+    ET_SWITCH_COMPLEX_TYPES(common_type, ctx, op_name, CTYPE, [&]() {
+      const CTYPE* a_data = a.const_data_ptr<CTYPE>();
+      const CTYPE* b_data = b.const_data_ptr<CTYPE>();
+      CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
+      for (size_t i = 0; i < out.numel(); ++i) {
+        out_data[i] = a_data[i] / b_data[i];
+      }
+    });
+    return out;
+  }
+
   if (a.numel() == 1 || b.numel() == 1) {
     if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half &&
         a_type != ScalarType::BFloat16) {
diff --git a/kernels/portable/cpu/op_div.cpp b/kernels/portable/cpu/op_div.cpp
index f94f219d853..299997ad322 100644
--- a/kernels/portable/cpu/op_div.cpp
+++ b/kernels/portable/cpu/op_div.cpp
@@ -20,7 +20,10 @@ namespace native {
 namespace {
 
 ScalarType get_common_type(ScalarType a_type, ScalarType b_type) {
-  if (isFloatingType(a_type) && isFloatingType(b_type)) {
+  if (executorch::runtime::isComplexType(a_type) ||
+      executorch::runtime::isComplexType(b_type)) {
+    return promoteTypes(a_type, b_type);
+  } else if (isFloatingType(a_type) && isFloatingType(b_type)) {
     return promoteTypes(a_type, b_type);
   } else if (isFloatingType(a_type)) {
     return a_type;
@@ -51,25 +54,35 @@ Tensor& div_out(
       InvalidArgument,
       out);
 
-  // Compute Dtype
-  ScalarType compute_type = utils::get_compute_type(common_type);
-
   // @lint-ignore CLANGTIDY facebook-hte-CArray
   static constexpr const char op_name[] = "div.out";
 
-  ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    utils::apply_bitensor_elementwise_fn<
-        CTYPE_COMPUTE,
-        op_name,
-        utils::SupportedTensorDtypes::FLOATHBF16>(
-        [](const auto& val_a, const auto& val_b) { return val_a / val_b; },
-        ctx,
-        a,
-        utils::SupportedTensorDtypes::REALHBBF16,
-        b,
-        utils::SupportedTensorDtypes::REALHBBF16,
-        out);
-  });
+  if (executorch::runtime::isComplexType(common_type)) {
+    ET_SWITCH_COMPLEX_TYPES(common_type, ctx, op_name, CTYPE, [&]() {
+      const CTYPE* a_data = a.const_data_ptr<CTYPE>();
+      const CTYPE* b_data = b.const_data_ptr<CTYPE>();
+      CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
+      for (ssize_t i = 0; i < out.numel(); ++i) {
+        out_data[i] = a_data[i] / b_data[i];
+      }
+    });
+  } else {
+    // Compute Dtype for real types
+    ScalarType compute_type = utils::get_compute_type(common_type);
+    ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+      utils::apply_bitensor_elementwise_fn<
+          CTYPE_COMPUTE,
+          op_name,
+          utils::SupportedTensorDtypes::FLOATHBF16>(
+          [](const auto& val_a, const auto& val_b) { return val_a / val_b; },
+          ctx,
+          a,
+          utils::SupportedTensorDtypes::REALHBBF16,
+          b,
+          utils::SupportedTensorDtypes::REALHBBF16,
+          out);
+    });
+  }
 
   return out;
 }
diff --git a/kernels/test/op_div_test.cpp b/kernels/test/op_div_test.cpp
index d4ff2b99121..94f26d1b301 100644
--- a/kernels/test/op_div_test.cpp
+++ b/kernels/test/op_div_test.cpp
@@ -601,3 +601,98 @@ TEST_F(OpDivScalarOutTest, OptimizedSanityCheck) {
   // Check that it matches the expected output.
   EXPECT_TENSOR_CLOSE(out, tf.make(sizes, {0.65, 1.05, 2.3, 4.1}));
 }
+
+//
+// Complex Type Tests
+//
+
+TEST_F(OpDivOutTest, ComplexFloatBasic) {
+  TensorFactory<ScalarType::ComplexFloat> tf;
+
+  const std::vector<int32_t> sizes = {2, 2};
+
+  // (1+2i) / (1+0i) = (1+2i)
+  // (4+4i) / (2+0i) = (2+2i)
+  // (3+4i) / (1-1i) = (3+4i)(1+1i) / 2 = (-1+7i) / 2 = (-0.5+3.5i)
+  // (8+0i) / (2+2i) = (8)(2-2i) / 8 = (2-2i)
+  Tensor a = tf.make(
+      sizes,
+      {executorch::aten::complex<float>(1.0f, 2.0f),
+       executorch::aten::complex<float>(4.0f, 4.0f),
+       executorch::aten::complex<float>(3.0f, 4.0f),
+       executorch::aten::complex<float>(8.0f, 0.0f)});
+
+  Tensor b = tf.make(
+      sizes,
+      {executorch::aten::complex<float>(1.0f, 0.0f),
+       executorch::aten::complex<float>(2.0f, 0.0f),
+       executorch::aten::complex<float>(1.0f, -1.0f),
+       executorch::aten::complex<float>(2.0f, 2.0f)});
+
+  Tensor out = tf.zeros(sizes);
+
+  op_div_out(a, b, out);
+
+  Tensor expected = tf.make(
+      sizes,
+      {executorch::aten::complex<float>(1.0f, 2.0f),
+       executorch::aten::complex<float>(2.0f, 2.0f),
+       executorch::aten::complex<float>(-0.5f, 3.5f),
+       executorch::aten::complex<float>(2.0f, -2.0f)});
+
+  EXPECT_TENSOR_CLOSE(out, expected);
+}
+
+TEST_F(OpDivOutTest, ComplexDoubleBasic) {
+  TensorFactory<ScalarType::ComplexDouble> tf;
+
+  const std::vector<int32_t> sizes = {2};
+
+  Tensor a = tf.make(
+      sizes,
+      {executorch::aten::complex<double>(6.0, 8.0),
+       executorch::aten::complex<double>(4.0, 0.0)});
+
+  Tensor b = tf.make(
+      sizes,
+      {executorch::aten::complex<double>(2.0, 0.0),
+       executorch::aten::complex<double>(0.0, 2.0)});
+
+  Tensor out = tf.zeros(sizes);
+
+  op_div_out(a, b, out);
+
+  // (6+8i) / 2 = (3+4i)
+  // 4 / 2i = 4 * (-i) / 2 = -2i = (0-2i)
+  Tensor expected = tf.make(
+      sizes,
+      {executorch::aten::complex<double>(3.0, 4.0),
+       executorch::aten::complex<double>(0.0, -2.0)});
+
+  EXPECT_TENSOR_CLOSE(out, expected);
+}
+
+TEST_F(OpDivOutTest, ComplexFloatIdentity) {
+  TensorFactory<ScalarType::ComplexFloat> tf;
+
+  const std::vector<int32_t> sizes = {3};
+
+  // Dividing by 1 should return the same value
+  Tensor a = tf.make(
+      sizes,
+      {executorch::aten::complex<float>(1.0f, 2.0f),
+       executorch::aten::complex<float>(3.0f, 4.0f),
+       executorch::aten::complex<float>(-5.0f, 6.0f)});
+
+  Tensor one = tf.make(
+      sizes,
+      {executorch::aten::complex<float>(1.0f, 0.0f),
+       executorch::aten::complex<float>(1.0f, 0.0f),
+       executorch::aten::complex<float>(1.0f, 0.0f)});
+
+  Tensor out = tf.zeros(sizes);
+
+  op_div_out(a, one, out);
+
+  EXPECT_TENSOR_CLOSE(out, a);
+}
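
Note (editorial aside, not part of the patch): the expected values in the new complex tests follow from ordinary complex division, which is the same a[i] / b[i] the kernel paths above apply elementwise. A minimal standalone sketch using std::complex reproduces the ComplexFloatBasic case (3+4i) / (1-1i); it only illustrates the arithmetic and does not use any ExecuTorch API.

#include <complex>
#include <cstdio>

int main() {
  // Same case as the third element of ComplexFloatBasic:
  // (3+4i) / (1-1i) = (3+4i)(1+1i) / |1-1i|^2 = (-1+7i) / 2 = -0.5+3.5i
  const std::complex<float> a(3.0f, 4.0f);
  const std::complex<float> b(1.0f, -1.0f);
  const std::complex<float> q = a / b;
  std::printf("%.2f%+.2fi\n", q.real(), q.imag());  // prints -0.50+3.50i
  return 0;
}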