23 changes: 19 additions & 4 deletions kernels/optimized/cpu/op_div.cpp
@@ -21,10 +21,11 @@ namespace native {
 namespace {
 
 ScalarType get_common_type(ScalarType a_type, ScalarType b_type) {
-  ET_CHECK(
-      !isComplexType(a_type) && !isQIntType(a_type) && !isBitsType(a_type));
-  ET_CHECK(
-      !isComplexType(b_type) && !isQIntType(b_type) && !isBitsType(b_type));
+  if (isComplexType(a_type) || isComplexType(b_type)) {
+    return promoteTypes(a_type, b_type);
+  }
+  ET_CHECK(!isQIntType(a_type) && !isBitsType(a_type));
+  ET_CHECK(!isQIntType(b_type) && !isBitsType(b_type));
 
   if (isFloatingType(a_type) && isFloatingType(b_type)) {
     return promoteTypes(a_type, b_type);
@@ -61,6 +62,20 @@ Tensor& opt_div_out(
   ScalarType b_type = b.scalar_type();
   ScalarType out_type = out.scalar_type();
 
+  // Handle complex types
+  if (isComplexType(a_type) || isComplexType(b_type)) {
+    ScalarType common_type = get_common_type(a_type, b_type);
+    ET_SWITCH_COMPLEX_TYPES(common_type, ctx, op_name, CTYPE, [&]() {
+      const CTYPE* a_data = a.const_data_ptr<CTYPE>();
+      const CTYPE* b_data = b.const_data_ptr<CTYPE>();
+      CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
+      for (size_t i = 0; i < out.numel(); ++i) {
+        out_data[i] = a_data[i] / b_data[i];
+      }
+    });
+    return out;
+  }
+
   if (a.numel() == 1 || b.numel() == 1) {
     if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half &&
         a_type != ScalarType::BFloat16) {
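The optimized kernel's new complex path is a plain elementwise loop; ET_SWITCH_COMPLEX_TYPES only selects the concrete element type. Below is a minimal standalone sketch of the same computation, using std::complex<float> as a stand-in for the dispatched CTYPE (div_complex and the vector containers are illustrative, not ExecuTorch API):

#include <complex>
#include <cstddef>
#include <vector>

// Elementwise complex division, mirroring the loop in the new fast path.
// Mathematically, a / b == (a * conj(b)) / |b|^2.
std::vector<std::complex<float>> div_complex(
    const std::vector<std::complex<float>>& a,
    const std::vector<std::complex<float>>& b) {
  std::vector<std::complex<float>> out(a.size());
  for (std::size_t i = 0; i < a.size(); ++i) {
    out[i] = a[i] / b[i];
  }
  return out;
}

For example, div_complex({{3.0f, 4.0f}}, {{1.0f, -1.0f}}) yields {(-0.5f, 3.5f)}, matching the expectation in the tests below.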
47 changes: 30 additions & 17 deletions kernels/portable/cpu/op_div.cpp
@@ -20,7 +20,10 @@ namespace native {
 namespace {
 
 ScalarType get_common_type(ScalarType a_type, ScalarType b_type) {
-  if (isFloatingType(a_type) && isFloatingType(b_type)) {
+  if (executorch::runtime::isComplexType(a_type) ||
+      executorch::runtime::isComplexType(b_type)) {
+    return promoteTypes(a_type, b_type);
+  } else if (isFloatingType(a_type) && isFloatingType(b_type)) {
     return promoteTypes(a_type, b_type);
   } else if (isFloatingType(a_type)) {
     return a_type;
@@ -51,25 +54,35 @@ Tensor& div_out(
       InvalidArgument,
       out);
 
-  // Compute Dtype
-  ScalarType compute_type = utils::get_compute_type(common_type);
-
   // @lint-ignore CLANGTIDY facebook-hte-CArray
   static constexpr const char op_name[] = "div.out";
 
-  ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
-    utils::apply_bitensor_elementwise_fn<
-        CTYPE_COMPUTE,
-        op_name,
-        utils::SupportedTensorDtypes::FLOATHBF16>(
-        [](const auto& val_a, const auto& val_b) { return val_a / val_b; },
-        ctx,
-        a,
-        utils::SupportedTensorDtypes::REALHBBF16,
-        b,
-        utils::SupportedTensorDtypes::REALHBBF16,
-        out);
-  });
+  if (executorch::runtime::isComplexType(common_type)) {
+    ET_SWITCH_COMPLEX_TYPES(common_type, ctx, op_name, CTYPE, [&]() {
+      const CTYPE* a_data = a.const_data_ptr<CTYPE>();
+      const CTYPE* b_data = b.const_data_ptr<CTYPE>();
+      CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
+      for (ssize_t i = 0; i < out.numel(); ++i) {
+        out_data[i] = a_data[i] / b_data[i];
+      }
+    });
+  } else {
+    // Compute Dtype for real types
+    ScalarType compute_type = utils::get_compute_type(common_type);
+    ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+      utils::apply_bitensor_elementwise_fn<
+          CTYPE_COMPUTE,
+          op_name,
+          utils::SupportedTensorDtypes::FLOATHBF16>(
+          [](const auto& val_a, const auto& val_b) { return val_a / val_b; },
+          ctx,
+          a,
+          utils::SupportedTensorDtypes::REALHBBF16,
+          b,
+          utils::SupportedTensorDtypes::REALHBBF16,
+          out);
+    });
+  }
 
   return out;
 }
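The portable kernel takes the same complex branch, so both implementations share one dispatch idiom: map the runtime ScalarType to a concrete ctype and run a lambda with that type bound to CTYPE. A simplified sketch of that pattern, assuming only two complex dtypes and a hypothetical switch_complex helper in place of the real macro machinery:

#include <complex>
#include <iostream>

// Hypothetical stand-in for an ET_SWITCH_COMPLEX_TYPES-style dispatcher:
// pick the concrete ctype from a runtime flag and pass a tag value whose
// static type carries the selection into the lambda.
template <typename Fn>
void switch_complex(bool is_double, Fn&& fn) {
  if (is_double) {
    fn(std::complex<double>{});
  } else {
    fn(std::complex<float>{});
  }
}

int main() {
  switch_complex(/*is_double=*/false, [](auto tag) {
    using CTYPE = decltype(tag);  // plays the role of CTYPE in the kernels
    CTYPE a(3.0f, 4.0f), b(1.0f, -1.0f);
    std::cout << a / b << "\n";  // prints (-0.5,3.5)
  });
}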
95 changes: 95 additions & 0 deletions kernels/test/op_div_test.cpp
@@ -601,3 +601,98 @@ TEST_F(OpDivScalarOutTest, OptimizedSanityCheck) {
   // Check that it matches the expected output.
   EXPECT_TENSOR_CLOSE(out, tf.make(sizes, {0.65, 1.05, 2.3, 4.1}));
 }
+
+//
+// Complex Type Tests
+//
+
+TEST_F(OpDivOutTest, ComplexFloatBasic) {
+  TensorFactory<ScalarType::ComplexFloat> tf;
+
+  const std::vector<int32_t> sizes = {2, 2};
+
+  // (1+2i) / (1+0i) = (1+2i)
+  // (4+4i) / (2+0i) = (2+2i)
+  // (3+4i) / (1-1i) = (3+4i)(1+1i) / 2 = (-1+7i) / 2 = (-0.5+3.5i)
+  // (8+0i) / (2+2i) = (8)(2-2i) / 8 = (2-2i)
+  Tensor a = tf.make(
+      sizes,
+      {executorch::aten::complex<float>(1.0f, 2.0f),
+       executorch::aten::complex<float>(4.0f, 4.0f),
+       executorch::aten::complex<float>(3.0f, 4.0f),
+       executorch::aten::complex<float>(8.0f, 0.0f)});
+
+  Tensor b = tf.make(
+      sizes,
+      {executorch::aten::complex<float>(1.0f, 0.0f),
+       executorch::aten::complex<float>(2.0f, 0.0f),
+       executorch::aten::complex<float>(1.0f, -1.0f),
+       executorch::aten::complex<float>(2.0f, 2.0f)});
+
+  Tensor out = tf.zeros(sizes);
+
+  op_div_out(a, b, out);
+
+  Tensor expected = tf.make(
+      sizes,
+      {executorch::aten::complex<float>(1.0f, 2.0f),
+       executorch::aten::complex<float>(2.0f, 2.0f),
+       executorch::aten::complex<float>(-0.5f, 3.5f),
+       executorch::aten::complex<float>(2.0f, -2.0f)});
+
+  EXPECT_TENSOR_CLOSE(out, expected);
+}
+
+TEST_F(OpDivOutTest, ComplexDoubleBasic) {
+  TensorFactory<ScalarType::ComplexDouble> tf;
+
+  const std::vector<int32_t> sizes = {2};
+
+  Tensor a = tf.make(
+      sizes,
+      {executorch::aten::complex<double>(6.0, 8.0),
+       executorch::aten::complex<double>(4.0, 0.0)});
+
+  Tensor b = tf.make(
+      sizes,
+      {executorch::aten::complex<double>(2.0, 0.0),
+       executorch::aten::complex<double>(0.0, 2.0)});
+
+  Tensor out = tf.zeros(sizes);
+
+  op_div_out(a, b, out);
+
+  // (6+8i) / 2 = (3+4i)
+  // 4 / 2i = 4 * (-i) / 2 = -2i = (0-2i)
+  Tensor expected = tf.make(
+      sizes,
+      {executorch::aten::complex<double>(3.0, 4.0),
+       executorch::aten::complex<double>(0.0, -2.0)});
+
+  EXPECT_TENSOR_CLOSE(out, expected);
+}
+
+TEST_F(OpDivOutTest, ComplexFloatIdentity) {
+  TensorFactory<ScalarType::ComplexFloat> tf;
+
+  const std::vector<int32_t> sizes = {3};
+
+  // Dividing by 1 should return the same value
+  Tensor a = tf.make(
+      sizes,
+      {executorch::aten::complex<float>(1.0f, 2.0f),
+       executorch::aten::complex<float>(3.0f, 4.0f),
+       executorch::aten::complex<float>(-5.0f, 6.0f)});
+
+  Tensor one = tf.make(
+      sizes,
+      {executorch::aten::complex<float>(1.0f, 0.0f),
+       executorch::aten::complex<float>(1.0f, 0.0f),
+       executorch::aten::complex<float>(1.0f, 0.0f)});
+
+  Tensor out = tf.zeros(sizes);
+
+  op_div_out(a, one, out);
+
+  EXPECT_TENSOR_CLOSE(out, a);
+}
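All expected values in these tests follow from the textbook quotient rule for complex numbers, which the inline comments apply case by case:

\[
\frac{a+bi}{c+di} = \frac{(a+bi)(c-di)}{c^2+d^2} = \frac{(ac+bd) + (bc-ad)i}{c^2+d^2},
\qquad\text{e.g.}\quad
\frac{3+4i}{1-i} = \frac{(3+4i)(1+i)}{2} = \frac{-1+7i}{2} = -0.5 + 3.5i.
\]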