From 5c066bbbc0e7852d120f6db18b8f3274147e2d27 Mon Sep 17 00:00:00 2001 From: Colin Unger Date: Fri, 29 May 2026 16:55:51 -0700 Subject: [PATCH] Add additional arithmetic operations for int_ge_two --- .../include/utils/int_ge_two/int_ge_two.h | 10 +++ lib/utils/src/utils/int_ge_two/int_ge_two.cc | 29 +++++++ .../test/src/utils/int_ge_two/int_ge_two.cc | 86 +++++++++++++++++++ 3 files changed, 125 insertions(+) create mode 100644 lib/utils/test/src/utils/int_ge_two/int_ge_two.cc diff --git a/lib/utils/include/utils/int_ge_two/int_ge_two.h b/lib/utils/include/utils/int_ge_two/int_ge_two.h index c22254b219..4f9e532513 100644 --- a/lib/utils/include/utils/int_ge_two/int_ge_two.h +++ b/lib/utils/include/utils/int_ge_two/int_ge_two.h @@ -87,6 +87,16 @@ struct int_ge_two { friend int_ge_two operator*(positive_int lhs, int_ge_two rhs); friend nonnegative_int operator*(nonnegative_int lhs, int_ge_two rhs); + friend positive_int &operator*=(positive_int &lhs, int_ge_two rhs); + friend nonnegative_int &operator*=(nonnegative_int &lhs, int_ge_two rhs); + + nonnegative_int operator/(int_ge_two other) const; + friend nonnegative_int operator/(positive_int lhs, int_ge_two rhs); + friend nonnegative_int operator/(nonnegative_int lhs, int_ge_two rhs); + + friend nonnegative_int operator%(positive_int lhs, int_ge_two rhs); + friend nonnegative_int operator%(nonnegative_int lhs, int_ge_two rhs); + int int_from_int_ge_two() const; nonnegative_int nonnegative_int_from_int_ge_two() const; positive_int positive_int_from_int_ge_two() const; diff --git a/lib/utils/src/utils/int_ge_two/int_ge_two.cc b/lib/utils/src/utils/int_ge_two/int_ge_two.cc index ae8e0fa42b..e300340b80 100644 --- a/lib/utils/src/utils/int_ge_two/int_ge_two.cc +++ b/lib/utils/src/utils/int_ge_two/int_ge_two.cc @@ -298,6 +298,35 @@ nonnegative_int operator*(nonnegative_int lhs, int_ge_two rhs) { return rhs * lhs; } +positive_int &operator*=(positive_int &lhs, int_ge_two rhs) { + return (lhs *= rhs.positive_int_from_int_ge_two()); +} + +nonnegative_int &operator*=(nonnegative_int &lhs, int_ge_two rhs) { + return (lhs *= rhs.nonnegative_int_from_int_ge_two()); +} + +nonnegative_int int_ge_two::operator/(int_ge_two other) const { + return this->positive_int_from_int_ge_two() / + other.positive_int_from_int_ge_two(); +} + +nonnegative_int operator/(positive_int lhs, int_ge_two rhs) { + return lhs / rhs.positive_int_from_int_ge_two(); +} + +nonnegative_int operator/(nonnegative_int lhs, int_ge_two rhs) { + return lhs / rhs.positive_int_from_int_ge_two(); +} + +nonnegative_int operator%(positive_int lhs, int_ge_two rhs) { + return lhs % rhs.positive_int_from_int_ge_two(); +} + +nonnegative_int operator%(nonnegative_int lhs, int_ge_two rhs) { + return lhs % rhs.positive_int_from_int_ge_two(); +} + int int_ge_two::int_from_int_ge_two() const { return this->value_; } diff --git a/lib/utils/test/src/utils/int_ge_two/int_ge_two.cc b/lib/utils/test/src/utils/int_ge_two/int_ge_two.cc new file mode 100644 index 0000000000..922d4df39b --- /dev/null +++ b/lib/utils/test/src/utils/int_ge_two/int_ge_two.cc @@ -0,0 +1,86 @@ +#include "utils/int_ge_two/int_ge_two.h" +#include + +using namespace ::FlexFlow; + +TEST_SUITE(FF_TEST_SUITE) { + TEST_CASE("int_ge_two") { + SUBCASE("constructor") { + SUBCASE("throws if value is less than 2") { + CHECK_THROWS(int_ge_two{1}); + CHECK_THROWS(int_ge_two{0}); + CHECK_THROWS(int_ge_two{-1}); + } + + SUBCASE("wraps the value if >= 2") { + int_ge_two x = int_ge_two{2}; + + CHECK(x.int_from_int_ge_two() == 2); + } + } + + SUBCASE("positive_int *= int_ge_two") { + positive_int x = 3_p; + x *= int_ge_two{4}; + + positive_int correct = 12_p; + + CHECK(x == correct); + } + + SUBCASE("nonnegative_int *= int_ge_two") { + SUBCASE("starting value is zero") { + nonnegative_int x = 0_n; + x *= int_ge_two{4}; + + nonnegative_int correct = 0_n; + + CHECK(x == correct); + } + + SUBCASE("starting value is nonzero") { + nonnegative_int x = 5_n; + x *= int_ge_two{4}; + + nonnegative_int correct = 20_n; + + CHECK(x == correct); + } + } + + SUBCASE("int_ge_two / int_ge_two") { + nonnegative_int result = int_ge_two{2} / int_ge_two{8}; + nonnegative_int correct = 0_n; + + CHECK(result == correct); + } + + SUBCASE("positive_int / int_ge_two") { + nonnegative_int result = positive_int{4} / int_ge_two{2}; + nonnegative_int correct = 2_n; + + CHECK(result == correct); + } + + SUBCASE("nonnegative_int / int_ge_two") { + nonnegative_int result = nonnegative_int{0} / int_ge_two{2}; + nonnegative_int correct = 0_n; + + CHECK(result == correct); + } + + SUBCASE("positive_int % int_ge_two") { + nonnegative_int result = positive_int{4} % int_ge_two{2}; + nonnegative_int correct = 0_n; + + CHECK(result == correct); + } + + SUBCASE("nonnegative_int % int_ge_two") { + nonnegative_int result = nonnegative_int{3} % int_ge_two{2}; + nonnegative_int correct = 1_n; + + CHECK(result == correct); + } + } +}