diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java index d73f8489b60..ddb2252f512 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java @@ -381,30 +381,28 @@ private static Hop removeUnnecessaryCumulativeOp(Hop parent, Hop hi, int pos) return hi; } - - private static Hop removeUnnecessaryReorgOperation(Hop parent, Hop hi, int pos) - { - if( hi instanceof ReorgOp ) - { + + private static Hop removeUnnecessaryReorgOperation(Hop parent, Hop hi, int pos) { + if( hi instanceof ReorgOp ) { ReorgOp rop = (ReorgOp) hi; - Hop input = hi.getInput(0); + Hop input = hi.getInput(0); boolean apply = false; - - //equal dims of reshape input and output -> no need for reshape because + + //equal dims of reshape input and output -> no need for reshape because //byrow always refers to both input/output and hence gives the same result apply |= (rop.getOp()==ReOrgOp.RESHAPE && HopRewriteUtils.isEqualSize(hi, input)); - - //1x1 dimensions of transpose/reshape -> no need for reorg - apply |= ((rop.getOp()==ReOrgOp.TRANS || rop.getOp()==ReOrgOp.RESHAPE) - && rop.getDim1()==1 && rop.getDim2()==1); - + + //1x1 dimensions of transpose/reshape/roll -> no need for reorg + apply |= ((rop.getOp()==ReOrgOp.TRANS || rop.getOp()==ReOrgOp.RESHAPE + || rop.getOp()==ReOrgOp.ROLL) && rop.getDim1()==1 && rop.getDim2()==1); + if( apply ) { HopRewriteUtils.replaceChildReference(parent, hi, input, pos); hi = input; LOG.debug("Applied removeUnnecessaryReorg."); } } - + return hi; } @@ -1356,44 +1354,78 @@ else if ( applyRight ) { * @param pos position * @return high-level operator */ - private static Hop pushdownSumOnAdditiveBinary(Hop parent, Hop hi, int pos) + private static Hop pushdownSumOnAdditiveBinary(Hop parent, Hop hi, int pos) { //all patterns headed by full sum over binary operation if( hi instanceof AggUnaryOp //full sum root over binaryop - && ((AggUnaryOp)hi).getDirection()==Direction.RowCol - && ((AggUnaryOp)hi).getOp() == AggOp.SUM - && hi.getInput(0) instanceof BinaryOp - && hi.getInput(0).getParent().size()==1 ) //single parent + && ((AggUnaryOp)hi).getDirection()==Direction.RowCol + && ((AggUnaryOp)hi).getOp() == AggOp.SUM + && hi.getInput(0) instanceof BinaryOp + && hi.getInput(0).getParent().size()==1 ) //single parent { BinaryOp bop = (BinaryOp) hi.getInput(0); Hop left = bop.getInput(0); Hop right = bop.getInput(1); - - if( HopRewriteUtils.isEqualSize(left, right) //dims(A) == dims(B) - && left.getDataType() == DataType.MATRIX - && right.getDataType() == DataType.MATRIX ) + + if( left.getDataType() == DataType.MATRIX + && right.getDataType() == DataType.MATRIX ) { OpOp2 applyOp = ( bop.getOp() == OpOp2.PLUS //pattern a: sum(A+B)->sum(A)+sum(B) || bop.getOp() == OpOp2.MINUS ) //pattern b: sum(A-B)->sum(A)-sum(B) ? bop.getOp() : null; - + if( applyOp != null ) { - //create new subdag sum(A) bop sum(B) - AggUnaryOp sum1 = HopRewriteUtils.createSum(left); - AggUnaryOp sum2 = HopRewriteUtils.createSum(right); - BinaryOp newBin = HopRewriteUtils.createBinary(sum1, sum2, applyOp); - - //rewire new subdag - HopRewriteUtils.replaceChildReference(parent, hi, newBin, pos); - HopRewriteUtils.cleanupUnreferenced(hi, bop); - - hi = newBin; - - LOG.debug("Applied pushdownSumOnAdditiveBinary (line "+hi.getBeginLine()+")."); + if (HopRewriteUtils.isEqualSize(left, right)) { + //create new subdag sum(A) bop sum(B) for equal-sized matrices + AggUnaryOp sum1 = HopRewriteUtils.createSum(left); + AggUnaryOp sum2 = HopRewriteUtils.createSum(right); + BinaryOp newBin = HopRewriteUtils.createBinary(sum1, sum2, applyOp); + //rewire new subdag + HopRewriteUtils.replaceChildReference(parent, hi, newBin, pos); + HopRewriteUtils.cleanupUnreferenced(hi, bop); + + hi = newBin; + + LOG.debug("Applied pushdownSumOnAdditiveBinary (line "+hi.getBeginLine()+")."); + } + // Check if right operand is a vector (has dimension of 1 in either rows or columns) + else if (right.getDim1() == 1 || right.getDim2() == 1) { + AggUnaryOp sum1 = HopRewriteUtils.createSum(left); + AggUnaryOp sum2 = HopRewriteUtils.createSum(right); + + // Row vector case (1 x n) + if (right.getDim1() == 1) { + // Create nrow(A) operation using dimensions + LiteralOp nRows = new LiteralOp(left.getDim1()); + BinaryOp scaledSum = HopRewriteUtils.createBinary(nRows, sum2, OpOp2.MULT); + BinaryOp newBin = HopRewriteUtils.createBinary(sum1, scaledSum, applyOp); + //rewire new subdag + HopRewriteUtils.replaceChildReference(parent, hi, newBin, pos); + HopRewriteUtils.cleanupUnreferenced(hi, bop); + + hi = newBin; + + LOG.debug("Applied pushdownSumOnAdditiveBinary with row vector (line "+hi.getBeginLine()+")."); + } + // Column vector case (n x 1) + else if (right.getDim2() == 1) { + // Create ncol(A) operation using dimensions + LiteralOp nCols = new LiteralOp(left.getDim2()); + BinaryOp scaledSum = HopRewriteUtils.createBinary(nCols, sum2, OpOp2.MULT); + BinaryOp newBin = HopRewriteUtils.createBinary(sum1, scaledSum, applyOp); + //rewire new subdag + HopRewriteUtils.replaceChildReference(parent, hi, newBin, pos); + HopRewriteUtils.cleanupUnreferenced(hi, bop); + + hi = newBin; + + LOG.debug("Applied pushdownSumOnAdditiveBinary with column vector (line "+hi.getBeginLine()+")."); + } + } } } } - + return hi; } diff --git a/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownSumOnBinaryTest.java b/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownSumOnBinaryTest.java index 9391af719a1..a8778b2b851 100644 --- a/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownSumOnBinaryTest.java +++ b/src/test/java/org/apache/sysds/test/functions/rewrite/RewritePushdownSumOnBinaryTest.java @@ -29,54 +29,93 @@ import org.apache.sysds.test.TestConfiguration; import org.apache.sysds.test.TestUtils; -public class RewritePushdownSumOnBinaryTest extends AutomatedTestBase +public class RewritePushdownSumOnBinaryTest extends AutomatedTestBase { private static final String TEST_NAME1 = "RewritePushdownSumOnBinary"; private static final String TEST_DIR = "functions/rewrite/"; private static final String TEST_CLASS_DIR = TEST_DIR + RewritePushdownSumOnBinaryTest.class.getSimpleName() + "/"; - + private static final int rows = 1000; private static final int cols = 1; - + @Override public void setUp() { TestUtils.clearAssertionInformation(); - addTestConfiguration( TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] { "R1", "R2" }) ); + addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, + new String[] { "R1", "R2", "R3", "R4" })); + } + + @Test + public void testRewritePushdownSumOnBinaryNoRewrite() { + testRewritePushdownSumOnBinary(TEST_NAME1, false); + } + + @Test + public void testRewritePushdownSumOnBinary() { + testRewritePushdownSumOnBinary(TEST_NAME1, true); } @Test - public void testRewritePushdownSumOnBinaryNoRewrite() { - testRewritePushdownSumOnBinary( TEST_NAME1, false ); + public void testRewritePushdownSumOnBinaryRowVector() { + testRewritePushdownSumOnBinaryVector(TEST_NAME1, true, true); } - + @Test - public void testRewritePushdownSumOnBinary() { - testRewritePushdownSumOnBinary( TEST_NAME1, true ); + public void testRewritePushdownSumOnBinaryColVector() { + testRewritePushdownSumOnBinaryVector(TEST_NAME1, true, false); } - - private void testRewritePushdownSumOnBinary( String testname, boolean rewrites ) - { + + private void testRewritePushdownSumOnBinary(String testname, boolean rewrites) { boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; - + try { TestConfiguration config = getTestConfiguration(testname); loadTestConfiguration(config); - + String HOME = SCRIPT_DIR + TEST_DIR; fullDMLScriptName = HOME + testname + ".dml"; - programArgs = new String[]{ "-args", String.valueOf(rows), - String.valueOf(cols), output("R1"), output("R2") }; + + programArgs = new String[]{ "-args", String.valueOf(rows), + String.valueOf(cols), output("R1"), output("R2"), + String.valueOf(rows), String.valueOf(cols) }; // Assuming row and col vectors + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; - //run performance tests + // Run performance tests runTest(true, false, null, -1); - - //compare matrices - long expect = Math.round(0.5*rows); + + // Compare matrices + long expect = Math.round(0.5 * rows); HashMap dmlfile1 = readDMLScalarFromOutputDir("R1"); - Assert.assertEquals(expect, dmlfile1.get(new CellIndex(1,1)), expect*0.01); + Assert.assertEquals(expect, dmlfile1.get(new CellIndex(1, 1)), expect * 0.01); HashMap dmlfile2 = readDMLScalarFromOutputDir("R2"); - Assert.assertEquals(expect, dmlfile2.get(new CellIndex(1,1)), expect*0.01); + Assert.assertEquals(expect, dmlfile2.get(new CellIndex(1, 1)), expect * 0.01); + } finally { + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; + } + } + + + private void testRewritePushdownSumOnBinaryVector(String testname, boolean rewrites, boolean isRow) { + boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; + try { + TestConfiguration config = getTestConfiguration(testname); + loadTestConfiguration(config); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + testname + ".dml"; + programArgs = new String[]{ "-args", String.valueOf(rows), + String.valueOf(cols), output("R3"), output("R4"), + String.valueOf(isRow ? 1 : rows), String.valueOf(isRow ? cols : 1) }; + OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = rewrites; + + runTest(true, false, null, -1); + + long expect = Math.round(500); // Expected value for 0.5 + 0.5 + HashMap dmlfile3 = readDMLScalarFromOutputDir("R3"); + Assert.assertEquals(expect, dmlfile3.get(new CellIndex(1,1)), expect*0.01); + HashMap dmlfile4 = readDMLScalarFromOutputDir("R4"); + Assert.assertEquals(expect, dmlfile4.get(new CellIndex(1,1)), expect*0.01); } finally { OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag; diff --git a/src/test/scripts/functions/rewrite/RewritePushdownSumOnBinary.dml b/src/test/scripts/functions/rewrite/RewritePushdownSumOnBinary.dml index d48ac0aad82..0d1b8123978 100644 --- a/src/test/scripts/functions/rewrite/RewritePushdownSumOnBinary.dml +++ b/src/test/scripts/functions/rewrite/RewritePushdownSumOnBinary.dml @@ -19,15 +19,30 @@ # #------------------------------------------------------------- -A = rand(rows=$1, cols=$2, seed=1); -B = rand(rows=$1, cols=$2, seed=2); -C = rand(rows=$1, cols=$2, seed=3); -D = rand(rows=$1, cols=$2, seed=4); +# Required parameters +A = matrix(0.5, rows=$1, cols=$2); +B = matrix(0.5, rows=$1, cols=$2); +C = matrix(0.5, rows=$1, cols=$2); +D = matrix(0.5, rows=$1, cols=$2); +# Set defaults for optional parameters +rowsV = ifdef($5, 0) +colsV = ifdef($6, 0) + +# Original matrix tests r1 = sum(A*B + C*D); r2 = r1; -print("r1="+r1+", r2="+r2); +# Vector tests +if (rowsV != 0 & colsV != 0) { + V = matrix(0.5, rows=rowsV, cols=colsV); + r3 = sum(A + V); + r4 = r3; +} + write(r1, $3); write(r2, $4); - +if (rowsV != 0 & colsV != 0) { + write(r3, $5); + write(r4, $6); +} \ No newline at end of file