diff --git a/scripts/builtin/matrixSqrt.dml b/scripts/builtin/matrixSqrt.dml new file mode 100644 index 00000000000..f1ea468a2ca --- /dev/null +++ b/scripts/builtin/matrixSqrt.dml @@ -0,0 +1,131 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +matrixSqrt = function( + Matrix[Double] X +)return( + Matrix[Double] sqrt_x +){ + N = nrow(X); + D = ncol(X); + + #check that matrix is square + if (D != N){ + stop("matrixSqrt Input Error: matrix not square!") + } + + # Any non singualar square matrix has a square root + isDiag = isDiagonal(X) + if(isDiag) { + sqrt_x = sqrtDiagMatrix(X); + } else { + [eValues, eVectors] = eigen(X); + + hasNonNegativeEigenValues = TRUE + l = length(eValues) + for (i in 1:l){ + gtZero = as.scalar(eValues[i]) >= 0.0; + hasNonNegativeEigenValues = gtZero & hasNonNegativeEigenValues; + } + + if(!hasNonNegativeEigenValues) { + stop("matrixSqrt exec Error: matrix has imaginary square root"); + } + + X_t = t(X); + isSymmetric = TRUE; + + for (i in 1:N) { + for (j in 1:D) { + same = as.scalar(X[i, j]) == as.scalar(X_t[i, j]); + isSymmetric = isSymmetric & same; + } + } + + allEigenValuesUnique = length(eValues) == length(unique(eValues)); + + if(allEigenValuesUnique | isSymmetric) { + # calculate X = VDV^(-1) -> S = sqrt(D) -> sqrt_x = VSV^(-1) + sqrtD = sqrtDiagMatrix(diag(eValues)); + V_Inv = inv(eVectors); + sqrt_x = eVectors %*% sqrtD %*% V_Inv; + } else { + #formular: (Denman–Beavers iteration) + Y = X + #identity matrix + Z = diag(matrix(1.0, rows=N, cols=1)) + + for (x in 1:100) { + Y_new = (1 / 2) * (Y + inv(Z)) + Z_new = (1 / 2) * (Z + inv(Y)) + Y = Y_new + Z = Z_new + } + sqrt_x = Y + } + } +} + +# assumes square and diagonal matrix +sqrtDiagMatrix = function( + Matrix[Double] X +)return( + Matrix[Double] sqrt_x +){ + N = nrow(X); + + #check if identity matrix + is_identity = TRUE; + for (i in 1:N) { + is_idElement = as.scalar(X[i, i]) == 1.0; + is_identity = is_identity & is_idElement; + } + + if(is_identity) { + sqrt_x = X; + } else { + sqrt_x = matrix(0, rows=N, cols=N); + tmp = 0 + for (i in 1:N) { + #workaround needed to access variable for it to be initialized + tmp = tmp + i + value = X[i, i]; + sqrt_x[i, i] = sqrt(value); + } + } +} + +isDiagonal = function ( + Matrix[Double] X +)return( + boolean diagonal +){ + N = nrow(X); + D = ncol(X); + noCells = N * D; + + diag = diag(diag(X)); + compare = X == diag; + sameCells = sum(compare); + + #all cells should be the same to be diagonal + diagonal = noCells == sameCells; +} \ No newline at end of file diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java index a6331905ac4..8e3239dc985 100644 --- a/src/main/java/org/apache/sysds/common/Builtins.java +++ b/src/main/java/org/apache/sysds/common/Builtins.java @@ -325,6 +325,8 @@ public enum Builtins { STEPLM("steplm",true, ReturnType.MULTI_RETURN), STFT("stft", false, ReturnType.MULTI_RETURN), SQRT("sqrt", false), + SQRT_MATRIX("matrixSqrt", true), + SQRT_MATRIX_JAVA("sqrt_matrix_java", false, ReturnType.SINGLE_RETURN), SUM("sum", false), SVD("svd", false, ReturnType.MULTI_RETURN), TABLE("table", "ctable", false), diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java index dd351ae894c..21595efd03b 100644 --- a/src/main/java/org/apache/sysds/common/Types.java +++ b/src/main/java/org/apache/sysds/common/Types.java @@ -542,7 +542,7 @@ public enum OpOp1 { CUMSUMPROD, DETECTSCHEMA, COLNAMES, EIGEN, EXISTS, EXP, FLOOR, INVERSE, IQM, ISNA, ISNAN, ISINF, LENGTH, LINEAGE, LOG, NCOL, NOT, NROW, MEDIAN, PREFETCH, PRINT, ROUND, SIN, SINH, SIGN, SOFTMAX, SQRT, STOP, _EVICT, - SVD, TAN, TANH, TYPEOF, TRIGREMOTE, + SVD, TAN, TANH, TYPEOF, TRIGREMOTE, SQRT_MATRIX_JAVA, //fused ML-specific operators for performance SPROP, //sample proportion: P * (1 - P) SIGMOID, //sigmoid function: 1 / (1 + exp(-X)) diff --git a/src/main/java/org/apache/sysds/hops/UnaryOp.java b/src/main/java/org/apache/sysds/hops/UnaryOp.java index 2c0cd4a61ba..1bda77530bb 100644 --- a/src/main/java/org/apache/sysds/hops/UnaryOp.java +++ b/src/main/java/org/apache/sysds/hops/UnaryOp.java @@ -512,7 +512,7 @@ && getInput().get(0).getParent().size()==1 ) //unary is only parent //ensure cp exec type for single-node operations if( _op == OpOp1.PRINT || _op == OpOp1.ASSERT || _op == OpOp1.STOP || _op == OpOp1.TYPEOF - || _op == OpOp1.INVERSE || _op == OpOp1.EIGEN || _op == OpOp1.CHOLESKY || _op == OpOp1.SVD + || _op == OpOp1.INVERSE || _op == OpOp1.EIGEN || _op == OpOp1.CHOLESKY || _op == OpOp1.SVD || _op == OpOp1.SQRT_MATRIX_JAVA || getInput().get(0).getDataType() == DataType.LIST || isMetadataOperation() ) { _etype = ExecType.CP; diff --git a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java index 1de3442dd9d..c12e4c4705f 100644 --- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java @@ -1759,6 +1759,19 @@ && isConstant(in[2]) output.setDimensions(in.getDim1(), in.getDim2()); output.setBlocksize(in.getBlocksize()); break; + + case SQRT_MATRIX_JAVA: + + checkNumParameters(1); + checkMatrixParam(getFirstExpr()); + output.setDataType(DataType.MATRIX); + output.setValueType(ValueType.FP64); + Identifier sqrt = getFirstExpr().getOutput(); + if(sqrt.dimsKnown() && sqrt.getDim1() != sqrt.getDim2()) + raiseValidateError("Input to sqrtMatrix() must be square matrix -- given: a " + sqrt.getDim1() + "x" + sqrt.getDim2() + " matrix.", conditional); + output.setDimensions( sqrt.getDim1(), sqrt.getDim2()); + output.setBlocksize( sqrt.getBlocksize()); + break; case CHOLESKY: { diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java b/src/main/java/org/apache/sysds/parser/DMLTranslator.java index 6121711933a..b0673be092e 100644 --- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java +++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java @@ -2749,6 +2749,7 @@ else if ( in.length == 2 ) break; case INVERSE: + case SQRT_MATRIX_JAVA: case CHOLESKY: case TYPEOF: case DETECTSCHEMA: diff --git a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java index a0270f6b20c..2d19b39f8a7 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java @@ -208,6 +208,7 @@ public class CPInstructionParser extends InstructionParser { String2CPInstructionType.put( "ucummax", CPType.Unary); String2CPInstructionType.put( "stop" , CPType.Unary); String2CPInstructionType.put( "inverse", CPType.Unary); + String2CPInstructionType.put( "sqrt_matrix_java", CPType.Unary); String2CPInstructionType.put( "cholesky",CPType.Unary); String2CPInstructionType.put( "sprop", CPType.Unary); String2CPInstructionType.put( "sigmoid", CPType.Unary); diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java index 61a5f0d7842..5365944a3be 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibCommonsMath.java @@ -80,7 +80,7 @@ private LibCommonsMath() { } public static boolean isSupportedUnaryOperation( String opcode ) { - return ( opcode.equals("inverse") || opcode.equals("cholesky") ); + return ( opcode.equals("inverse") || opcode.equals("cholesky") || opcode.equals("sqrt_matrix_java") ); } public static boolean isSupportedMultiReturnOperation( String opcode ) { @@ -111,6 +111,8 @@ public static MatrixBlock unaryOperations(MatrixBlock inj, String opcode) { return computeMatrixInverse(matrixInput); else if (opcode.equals("cholesky")) return computeCholesky(matrixInput); + else if (opcode.equals("sqrt_matrix_java")) + return computeSqrt(inj); return null; } @@ -512,7 +514,19 @@ private static MatrixBlock[] computeSvd(MatrixBlock in) { return new MatrixBlock[] { U, Sigma, V }; } - + + /** + * Computes the square root of a matrix Calls Apache Commons Math EigenDecomposition. + * + * @param in Input matrix + * @return matrix block + */ + private static MatrixBlock computeSqrt(MatrixBlock in) { + Array2DRowRealMatrix matrixInput = DataConverter.convertToArray2DRowRealMatrix(in); + EigenDecomposition ed = new EigenDecomposition(matrixInput); + return DataConverter.convertToMatrixBlock(ed.getSquareRoot()); + } + /** * Function to compute matrix inverse via matrix decomposition. * diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinSQRTMatrixTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinSQRTMatrixTest.java new file mode 100644 index 00000000000..b8178d325f6 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/builtin/part2/BuiltinSQRTMatrixTest.java @@ -0,0 +1,237 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.builtin.part2; + +import org.apache.sysds.common.Types; +import org.apache.sysds.common.Types.ExecType; +import org.apache.sysds.runtime.matrix.data.MatrixValue; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; + +import java.util.HashMap; + +public class BuiltinSQRTMatrixTest extends AutomatedTestBase { + private final static String TEST_NAME = "SQRTMatrix"; + private final static String TEST_DIR = "functions/builtin/"; + private static final String TEST_CLASS_DIR = TEST_DIR + BuiltinSQRTMatrixTest.class.getSimpleName() + "/"; + + private final static double eps = 1e-8; + + @Override + public void setUp() { + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"C"})); + } + + // tests for strategy "COMMON" + @Test + public void testSQRTMatrixJavaSize1x1() { + runSQRTMatrix(true, ExecType.CP, "COMMON", 1); + } + + @Test + public void testSQRTMatrixJavaUpperTriangularMatrixSize2x2() { + runSQRTMatrix(true, ExecType.CP, "COMMON", 2); + } + + + @Test + public void testSQRTMatrixJavaDiagonalMatrixSize2x2() { + runSQRTMatrix(true, ExecType.CP, "COMMON", 3); + } + + @Test + public void testSQRTMatrixJavaPSDMatrixSize2x2() { + runSQRTMatrix(true, ExecType.CP, "COMMON", 4); + } + + @Test + public void testSQRTMatrixJavaPSDMatrixSize3x3() { + runSQRTMatrix(true, ExecType.CP, "COMMON", 5); + } + + @Test + public void testSQRTMatrixJavaPSDMatrixSize4x4() { + runSQRTMatrix(true, ExecType.CP, "COMMON", 6); + } + + @Test + public void testSQRTMatrixJavaPSDMatrixSize8x8() { + runSQRTMatrix(true, ExecType.CP, "COMMON", 7); + } + + // tests for strategy "DML" + @Test + public void testSQRTMatrixDMLSize1x1() { + runSQRTMatrix(true, ExecType.CP, "DML", 1); + } + + @Test + public void testSQRTMatrixDMLUpperTriangularMatrixSize2x2() { + runSQRTMatrix(true, ExecType.CP, "DML", 2); + } + + @Test + public void testSQRTMatrixDMLDiagonalMatrixSize2x2() { + runSQRTMatrix(true, ExecType.CP, "DML", 3); + } + + @Test + public void testSQRTMatrixDMLPSDMatrixSize2x2() { + runSQRTMatrix(true, ExecType.CP, "DML", 4); + } + + @Test + public void testSQRTMatrixDMLPSDMatrixSize3x3() { + runSQRTMatrix(true, ExecType.CP, "DML", 5); + } + + @Test + public void testSQRTMatrixDMLPSDMatrixSize4x4() { + runSQRTMatrix(true, ExecType.CP, "DML", 6); + } + + @Test + public void testSQRTMatrixDMLPSDMatrixSize8x8() { + runSQRTMatrix(true, ExecType.CP, "DML", 7); + } + + + private void runSQRTMatrix(boolean defaultProb, ExecType instType, String strategy, int test_case) { + Types.ExecMode platformOld = setExecMode(instType); + try { + loadTestConfiguration(getTestConfiguration(TEST_NAME)); + + // find path to associated dml script and define parameters + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME + ".dml"; + programArgs = new String[] {"-args", input("X"), strategy, output("Y")}; + + // define input matrix for the matrix sqrt function according to test case + double[][] X = null; + switch(test_case) { + case 1: // arbitrary square matrix of dimension 1x1 (PSD) + double[][] X1 = { + {4} + }; + X = X1; + break; + case 2: // arbitrary upper right triangular matrix (PSD) of dimension 2x2 + double[][] X2 = { + {1, 1}, + {1, 1}, + }; + X = X2; + break; + case 3: // arbitrary diagonal matrix (PSD) of dimension 2x2 + double[][] X3 = { + {1, 0}, + {0, 1}, + }; + X = X3; + break; + case 4: // arbitrary PSD matrix of dimension 2x2 + // PSD matrix generated by taking (A^T)A of matrix A = [[1, 0], [2, 3]] + double[][] X4 = { + {1, 2}, + {2, 13} + }; + X = X4; + break; + case 5: // arbitrary PSD matrix of dimension 3x3 + // PSD matrix generated by taking (A^T)A of matrix A = + // [[1.5, 0, 1.2], + // [2.2, 3.8, 4.4], + // [4.2, 6.1, 0.2]] + double[][] X5 = { + {3.69, 8.58, 6.54}, + {8.58, 38.64, 33.30}, + {6.54, 33.3, 54.89} + }; + X = X5; + break; + case 6: // arbitrary PSD matrix of dimension 4x4 + // PSD matrix generated by taking (A^T)A of matrix A= + // [[1, 0, 5, 6], + // [2, 3, 0, 2], + // [5, 0, 1, 1], + // [2, 3, 4, 8]] + double[][] X6 = { + {62, 14, 16, 70}, + {14, 17, 12, 29}, + {16, 12, 27, 22}, + {70, 29, 22, 93} + }; + X = X6; + break; + case 7: // arbitrary PSD matrix of dimension 8x8 + // PSD matrix generated by taking (A^T)A of matrix A = + // [[ 8.41557894, 3.44748042, 1.44911908, 4.95381036, 4.42875187, 4.14710712, -0.42719386, 6.1366026 ], + // [ 3.44748042, 11.38083039, 4.99475137, 3.36734826, 4.08943809, 4.23308448, 4.50030176, 3.92552912], + // [ 1.44911908, 4.99475137, 9.78651357, 4.00347878, 4.60244914, 4.24468227, 3.62945751, 6.54033601], + // [ 4.95381036, 3.36734826, 4.00347878, 12.75936071, 3.78643598, 1.99998784, 5.41689723, 7.9756991 ], + // [ 4.42875187, 4.08943809, 4.60244914, 3.78643598, 12.49158813, 6.69560056, 3.87176913, 5.5028702 ], + // [ 4.14710712, 4.23308448, 4.24468227, 1.99998784, 6.69560056, 7.66015758, 4.21792513, 4.53489207], + // [-0.42719386, 4.50030176, 3.62945751, 5.41689723, 3.87176913, 4.21792513, 9.07079513, 2.64352781], + // [ 6.1366026 , 3.92552912, 6.54033601, 7.9756991 , 5.5028702 , 4.53489207, 2.64352781, 8.92801728]] + double[][] X7 = { + {184, 150, 140, 194, 192, 153, 91, 211}, + {150, 248, 203, 198, 216, 187, 171, 214}, + {140, 203, 234, 212, 223, 185, 165, 237}, + {194, 198, 212, 326, 228, 177, 190, 287}, + {192, 216, 223, 228, 318, 239, 180, 262}, + {153, 187, 185, 177, 239, 199, 152, 209}, + { 91, 171, 165, 190, 180, 152, 185, 170}, + {211, 214, 237, 287, 262, 209, 170, 297} + }; + X = X7; + break; + } + + assert X != null; + + // write the input matrix and strategy for matrix sqrt function to dml script + writeInputMatrixWithMTD("X", X, true); + + // run the test dml script + runTest(true, false, null, -1); + + // read the result matrix from the dml script output Y + HashMap actual_Y = readDMLMatrixFromOutputDir("Y"); + + //System.out.println("This is the actual Y: " + actual_Y); + + // create a HashMap with Matrix Values from the input matrix X to compare to the received output matrix + HashMap expected_Y = new HashMap<>(); + for (int r = 0; r < X.length; r++) { + for (int c = 0; c < X[0].length; c++) { + expected_Y.put(new MatrixValue.CellIndex(r + 1, c + 1), X[r][c]); + } + } + + // compare the expected matrix (the input matrix X) with the received output matrix Y, which should be the (SQRT_MATRIX(X))^2 = X again + TestUtils.compareMatrices(expected_Y, actual_Y, eps, "Expected-DML", "Actual-DML"); + } + finally { + resetExecMode(platformOld); + } + } +} diff --git a/src/test/scripts/functions/builtin/SQRTMatrix.dml b/src/test/scripts/functions/builtin/SQRTMatrix.dml new file mode 100644 index 00000000000..0a103505bb0 --- /dev/null +++ b/src/test/scripts/functions/builtin/SQRTMatrix.dml @@ -0,0 +1,37 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +# DML script to test the Square Root Operator for matrices +# Result should be correct, if the result * result == input + +X = read($1) +S = $2 + +if (S == "COMMON") { + A = sqrt_matrix_java(X) +} else if (S == "DML") { + A = matrixSqrt(X) +} else { + stop("Error: Unknown strategy for matrix square root.") +} +Y = A %*% A + +write (Y, $3);