Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -381,30 +381,28 @@ private static Hop removeUnnecessaryCumulativeOp(Hop parent, Hop hi, int pos)

return hi;
}

private static Hop removeUnnecessaryReorgOperation(Hop parent, Hop hi, int pos)
{
if( hi instanceof ReorgOp )
{

private static Hop removeUnnecessaryReorgOperation(Hop parent, Hop hi, int pos) {
if( hi instanceof ReorgOp ) {
ReorgOp rop = (ReorgOp) hi;
Hop input = hi.getInput(0);
Hop input = hi.getInput(0);
boolean apply = false;
//equal dims of reshape input and output -> no need for reshape because

//equal dims of reshape input and output -> no need for reshape because
//byrow always refers to both input/output and hence gives the same result
apply |= (rop.getOp()==ReOrgOp.RESHAPE && HopRewriteUtils.isEqualSize(hi, input));
//1x1 dimensions of transpose/reshape -> no need for reorg
apply |= ((rop.getOp()==ReOrgOp.TRANS || rop.getOp()==ReOrgOp.RESHAPE)
&& rop.getDim1()==1 && rop.getDim2()==1);

//1x1 dimensions of transpose/reshape/roll -> no need for reorg
apply |= ((rop.getOp()==ReOrgOp.TRANS || rop.getOp()==ReOrgOp.RESHAPE
|| rop.getOp()==ReOrgOp.ROLL) && rop.getDim1()==1 && rop.getDim2()==1);

if( apply ) {
HopRewriteUtils.replaceChildReference(parent, hi, input, pos);
hi = input;
LOG.debug("Applied removeUnnecessaryReorg.");
}
}

return hi;
}

Expand Down Expand Up @@ -1356,44 +1354,78 @@ else if ( applyRight ) {
* @param pos position
* @return high-level operator
*/
private static Hop pushdownSumOnAdditiveBinary(Hop parent, Hop hi, int pos)
private static Hop pushdownSumOnAdditiveBinary(Hop parent, Hop hi, int pos)
{
//all patterns headed by full sum over binary operation
if( hi instanceof AggUnaryOp //full sum root over binaryop
&& ((AggUnaryOp)hi).getDirection()==Direction.RowCol
&& ((AggUnaryOp)hi).getOp() == AggOp.SUM
&& hi.getInput(0) instanceof BinaryOp
&& hi.getInput(0).getParent().size()==1 ) //single parent
&& ((AggUnaryOp)hi).getDirection()==Direction.RowCol
&& ((AggUnaryOp)hi).getOp() == AggOp.SUM
&& hi.getInput(0) instanceof BinaryOp
&& hi.getInput(0).getParent().size()==1 ) //single parent
{
BinaryOp bop = (BinaryOp) hi.getInput(0);
Hop left = bop.getInput(0);
Hop right = bop.getInput(1);

if( HopRewriteUtils.isEqualSize(left, right) //dims(A) == dims(B)
&& left.getDataType() == DataType.MATRIX
&& right.getDataType() == DataType.MATRIX )

if( left.getDataType() == DataType.MATRIX
&& right.getDataType() == DataType.MATRIX )
{
OpOp2 applyOp = ( bop.getOp() == OpOp2.PLUS //pattern a: sum(A+B)->sum(A)+sum(B)
|| bop.getOp() == OpOp2.MINUS ) //pattern b: sum(A-B)->sum(A)-sum(B)
? bop.getOp() : null;

if( applyOp != null ) {
//create new subdag sum(A) bop sum(B)
AggUnaryOp sum1 = HopRewriteUtils.createSum(left);
AggUnaryOp sum2 = HopRewriteUtils.createSum(right);
BinaryOp newBin = HopRewriteUtils.createBinary(sum1, sum2, applyOp);

//rewire new subdag
HopRewriteUtils.replaceChildReference(parent, hi, newBin, pos);
HopRewriteUtils.cleanupUnreferenced(hi, bop);

hi = newBin;

LOG.debug("Applied pushdownSumOnAdditiveBinary (line "+hi.getBeginLine()+").");
if (HopRewriteUtils.isEqualSize(left, right)) {
//create new subdag sum(A) bop sum(B) for equal-sized matrices
AggUnaryOp sum1 = HopRewriteUtils.createSum(left);
AggUnaryOp sum2 = HopRewriteUtils.createSum(right);
BinaryOp newBin = HopRewriteUtils.createBinary(sum1, sum2, applyOp);
//rewire new subdag
HopRewriteUtils.replaceChildReference(parent, hi, newBin, pos);
HopRewriteUtils.cleanupUnreferenced(hi, bop);

hi = newBin;

LOG.debug("Applied pushdownSumOnAdditiveBinary (line "+hi.getBeginLine()+").");
}
// Check if right operand is a vector (has dimension of 1 in either rows or columns)
else if (right.getDim1() == 1 || right.getDim2() == 1) {
AggUnaryOp sum1 = HopRewriteUtils.createSum(left);
AggUnaryOp sum2 = HopRewriteUtils.createSum(right);

// Row vector case (1 x n)
if (right.getDim1() == 1) {
// Create nrow(A) operation using dimensions
LiteralOp nRows = new LiteralOp(left.getDim1());
BinaryOp scaledSum = HopRewriteUtils.createBinary(nRows, sum2, OpOp2.MULT);
BinaryOp newBin = HopRewriteUtils.createBinary(sum1, scaledSum, applyOp);
//rewire new subdag
HopRewriteUtils.replaceChildReference(parent, hi, newBin, pos);
HopRewriteUtils.cleanupUnreferenced(hi, bop);

hi = newBin;

LOG.debug("Applied pushdownSumOnAdditiveBinary with row vector (line "+hi.getBeginLine()+").");
}
// Column vector case (n x 1)
else if (right.getDim2() == 1) {
// Create ncol(A) operation using dimensions
LiteralOp nCols = new LiteralOp(left.getDim2());
BinaryOp scaledSum = HopRewriteUtils.createBinary(nCols, sum2, OpOp2.MULT);
BinaryOp newBin = HopRewriteUtils.createBinary(sum1, scaledSum, applyOp);
//rewire new subdag
HopRewriteUtils.replaceChildReference(parent, hi, newBin, pos);
HopRewriteUtils.cleanupUnreferenced(hi, bop);

hi = newBin;

LOG.debug("Applied pushdownSumOnAdditiveBinary with column vector (line "+hi.getBeginLine()+").");
}
}
}
}
}

return hi;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,54 +29,93 @@
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;

public class RewritePushdownSumOnBinaryTest extends AutomatedTestBase
public class RewritePushdownSumOnBinaryTest extends AutomatedTestBase
{
private static final String TEST_NAME1 = "RewritePushdownSumOnBinary";
private static final String TEST_DIR = "functions/rewrite/";
private static final String TEST_CLASS_DIR = TEST_DIR + RewritePushdownSumOnBinaryTest.class.getSimpleName() + "/";

private static final int rows = 1000;
private static final int cols = 1;

@Override
public void setUp() {
TestUtils.clearAssertionInformation();
addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R1", "R2" }) );
addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1,
new String[] { "R1", "R2", "R3", "R4" }));
}

@Test
public void testRewritePushdownSumOnBinaryNoRewrite() {
testRewritePushdownSumOnBinary(TEST_NAME1, false);
}

@Test
public void testRewritePushdownSumOnBinary() {
testRewritePushdownSumOnBinary(TEST_NAME1, true);
}

@Test
public void testRewritePushdownSumOnBinaryNoRewrite() {
testRewritePushdownSumOnBinary( TEST_NAME1, false );
public void testRewritePushdownSumOnBinaryRowVector() {
testRewritePushdownSumOnBinaryVector(TEST_NAME1, true, true);
}

@Test
public void testRewritePushdownSumOnBinary() {
testRewritePushdownSumOnBinary( TEST_NAME1, true );
public void testRewritePushdownSumOnBinaryColVector() {
testRewritePushdownSumOnBinaryVector(TEST_NAME1, true, false);
}

private void testRewritePushdownSumOnBinary( String testname, boolean rewrites )
{

private void testRewritePushdownSumOnBinary(String testname, boolean rewrites) {
boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;

try {
TestConfiguration config = getTestConfiguration(testname);
loadTestConfiguration(config);

String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + testname + ".dml";
programArgs = new String[]{ "-args", String.valueOf(rows),
String.valueOf(cols), output("R1"), output("R2") };

programArgs = new String[]{ "-args", String.valueOf(rows),
String.valueOf(cols), output("R1"), output("R2"),
String.valueOf(rows), String.valueOf(cols) }; // Assuming row and col vectors

OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;

//run performance tests
// Run performance tests
runTest(true, false, null, -1);
//compare matrices
long expect = Math.round(0.5*rows);

// Compare matrices
long expect = Math.round(0.5 * rows);
HashMap<CellIndex, Double> dmlfile1 = readDMLScalarFromOutputDir("R1");
Assert.assertEquals(expect, dmlfile1.get(new CellIndex(1,1)), expect*0.01);
Assert.assertEquals(expect, dmlfile1.get(new CellIndex(1, 1)), expect * 0.01);
HashMap<CellIndex, Double> dmlfile2 = readDMLScalarFromOutputDir("R2");
Assert.assertEquals(expect, dmlfile2.get(new CellIndex(1,1)), expect*0.01);
Assert.assertEquals(expect, dmlfile2.get(new CellIndex(1, 1)), expect * 0.01);
} finally {
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
}
}


private void testRewritePushdownSumOnBinaryVector(String testname, boolean rewrites, boolean isRow) {
boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
try {
TestConfiguration config = getTestConfiguration(testname);
loadTestConfiguration(config);

String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + testname + ".dml";
programArgs = new String[]{ "-args", String.valueOf(rows),
String.valueOf(cols), output("R3"), output("R4"),
String.valueOf(isRow ? 1 : rows), String.valueOf(isRow ? cols : 1) };
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites;

runTest(true, false, null, -1);

long expect = Math.round(500); // Expected value for 0.5 + 0.5
HashMap<CellIndex, Double> dmlfile3 = readDMLScalarFromOutputDir("R3");
Assert.assertEquals(expect, dmlfile3.get(new CellIndex(1,1)), expect*0.01);
HashMap<CellIndex, Double> dmlfile4 = readDMLScalarFromOutputDir("R4");
Assert.assertEquals(expect, dmlfile4.get(new CellIndex(1,1)), expect*0.01);
}
finally {
OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
Expand Down
27 changes: 21 additions & 6 deletions src/test/scripts/functions/rewrite/RewritePushdownSumOnBinary.dml
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,30 @@
#
#-------------------------------------------------------------

A = rand(rows=$1, cols=$2, seed=1);
B = rand(rows=$1, cols=$2, seed=2);
C = rand(rows=$1, cols=$2, seed=3);
D = rand(rows=$1, cols=$2, seed=4);
# Required parameters
A = matrix(0.5, rows=$1, cols=$2);
B = matrix(0.5, rows=$1, cols=$2);
C = matrix(0.5, rows=$1, cols=$2);
D = matrix(0.5, rows=$1, cols=$2);

# Set defaults for optional parameters
rowsV = ifdef($5, 0)
colsV = ifdef($6, 0)

# Original matrix tests
r1 = sum(A*B + C*D);
r2 = r1;

print("r1="+r1+", r2="+r2);
# Vector tests
if (rowsV != 0 & colsV != 0) {
V = matrix(0.5, rows=rowsV, cols=colsV);
r3 = sum(A + V);
r4 = r3;
}

write(r1, $3);
write(r2, $4);

if (rowsV != 0 & colsV != 0) {
write(r3, $5);
write(r4, $6);
}
Loading