Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions lib/utils/include/utils/int_ge_two/int_ge_two.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,16 @@ struct int_ge_two {
friend int_ge_two operator*(positive_int lhs, int_ge_two rhs);
friend nonnegative_int operator*(nonnegative_int lhs, int_ge_two rhs);

friend positive_int &operator*=(positive_int &lhs, int_ge_two rhs);
friend nonnegative_int &operator*=(nonnegative_int &lhs, int_ge_two rhs);

nonnegative_int operator/(int_ge_two other) const;
friend nonnegative_int operator/(positive_int lhs, int_ge_two rhs);
friend nonnegative_int operator/(nonnegative_int lhs, int_ge_two rhs);

friend nonnegative_int operator%(positive_int lhs, int_ge_two rhs);
friend nonnegative_int operator%(nonnegative_int lhs, int_ge_two rhs);

int int_from_int_ge_two() const;
nonnegative_int nonnegative_int_from_int_ge_two() const;
positive_int positive_int_from_int_ge_two() const;
Expand Down
29 changes: 29 additions & 0 deletions lib/utils/src/utils/int_ge_two/int_ge_two.cc
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,35 @@ nonnegative_int operator*(nonnegative_int lhs, int_ge_two rhs) {
return rhs * lhs;
}

positive_int &operator*=(positive_int &lhs, int_ge_two rhs) {
return (lhs *= rhs.positive_int_from_int_ge_two());
}

nonnegative_int &operator*=(nonnegative_int &lhs, int_ge_two rhs) {
return (lhs *= rhs.nonnegative_int_from_int_ge_two());
}

nonnegative_int int_ge_two::operator/(int_ge_two other) const {
return this->positive_int_from_int_ge_two() /
other.positive_int_from_int_ge_two();
}

nonnegative_int operator/(positive_int lhs, int_ge_two rhs) {
return lhs / rhs.positive_int_from_int_ge_two();
}

nonnegative_int operator/(nonnegative_int lhs, int_ge_two rhs) {
return lhs / rhs.positive_int_from_int_ge_two();
}

nonnegative_int operator%(positive_int lhs, int_ge_two rhs) {
return lhs % rhs.positive_int_from_int_ge_two();
}

nonnegative_int operator%(nonnegative_int lhs, int_ge_two rhs) {
return lhs % rhs.positive_int_from_int_ge_two();
}

int int_ge_two::int_from_int_ge_two() const {
return this->value_;
}
Expand Down
86 changes: 86 additions & 0 deletions lib/utils/test/src/utils/int_ge_two/int_ge_two.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
#include "utils/int_ge_two/int_ge_two.h"
#include <doctest/doctest.h>

using namespace ::FlexFlow;

TEST_SUITE(FF_TEST_SUITE) {
TEST_CASE("int_ge_two") {
SUBCASE("constructor") {
SUBCASE("throws if value is less than 2") {
CHECK_THROWS(int_ge_two{1});
CHECK_THROWS(int_ge_two{0});
CHECK_THROWS(int_ge_two{-1});
}

SUBCASE("wraps the value if >= 2") {
int_ge_two x = int_ge_two{2};

CHECK(x.int_from_int_ge_two() == 2);
}
}

SUBCASE("positive_int *= int_ge_two") {
positive_int x = 3_p;
x *= int_ge_two{4};

positive_int correct = 12_p;

CHECK(x == correct);
}

SUBCASE("nonnegative_int *= int_ge_two") {
SUBCASE("starting value is zero") {
nonnegative_int x = 0_n;
x *= int_ge_two{4};

nonnegative_int correct = 0_n;

CHECK(x == correct);
}

SUBCASE("starting value is nonzero") {
nonnegative_int x = 5_n;
x *= int_ge_two{4};

nonnegative_int correct = 20_n;

CHECK(x == correct);
}
}

SUBCASE("int_ge_two / int_ge_two") {
nonnegative_int result = int_ge_two{2} / int_ge_two{8};
nonnegative_int correct = 0_n;

CHECK(result == correct);
}

SUBCASE("positive_int / int_ge_two") {
nonnegative_int result = positive_int{4} / int_ge_two{2};
nonnegative_int correct = 2_n;

CHECK(result == correct);
}

SUBCASE("nonnegative_int / int_ge_two") {
nonnegative_int result = nonnegative_int{0} / int_ge_two{2};
nonnegative_int correct = 0_n;

CHECK(result == correct);
}

SUBCASE("positive_int % int_ge_two") {
nonnegative_int result = positive_int{4} % int_ge_two{2};
nonnegative_int correct = 0_n;

CHECK(result == correct);
}

SUBCASE("nonnegative_int % int_ge_two") {
nonnegative_int result = nonnegative_int{3} % int_ge_two{2};
nonnegative_int correct = 1_n;

CHECK(result == correct);
}
}
}
Loading