Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 131 additions & 0 deletions scripts/builtin/matrixSqrt.dml
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------

matrixSqrt = function(
Matrix[Double] X
)return(
Matrix[Double] sqrt_x
){
N = nrow(X);
D = ncol(X);

#check that matrix is square
if (D != N){
stop("matrixSqrt Input Error: matrix not square!")
}

# Any non singualar square matrix has a square root
isDiag = isDiagonal(X)
if(isDiag) {
sqrt_x = sqrtDiagMatrix(X);
} else {
[eValues, eVectors] = eigen(X);

hasNonNegativeEigenValues = TRUE
l = length(eValues)
for (i in 1:l){
gtZero = as.scalar(eValues[i]) >= 0.0;
hasNonNegativeEigenValues = gtZero & hasNonNegativeEigenValues;
}

if(!hasNonNegativeEigenValues) {
stop("matrixSqrt exec Error: matrix has imaginary square root");
}

X_t = t(X);
isSymmetric = TRUE;

for (i in 1:N) {
for (j in 1:D) {
same = as.scalar(X[i, j]) == as.scalar(X_t[i, j]);
isSymmetric = isSymmetric & same;
}
}

allEigenValuesUnique = length(eValues) == length(unique(eValues));

if(allEigenValuesUnique | isSymmetric) {
# calculate X = VDV^(-1) -> S = sqrt(D) -> sqrt_x = VSV^(-1)
sqrtD = sqrtDiagMatrix(diag(eValues));
V_Inv = inv(eVectors);
sqrt_x = eVectors %*% sqrtD %*% V_Inv;
} else {
#formular: (Denman–Beavers iteration)
Y = X
#identity matrix
Z = diag(matrix(1.0, rows=N, cols=1))

for (x in 1:100) {
Y_new = (1 / 2) * (Y + inv(Z))
Z_new = (1 / 2) * (Z + inv(Y))
Y = Y_new
Z = Z_new
}
sqrt_x = Y
}
}
}

# assumes square and diagonal matrix
sqrtDiagMatrix = function(
Matrix[Double] X
)return(
Matrix[Double] sqrt_x
){
N = nrow(X);

#check if identity matrix
is_identity = TRUE;
for (i in 1:N) {
is_idElement = as.scalar(X[i, i]) == 1.0;
is_identity = is_identity & is_idElement;
}

if(is_identity) {
sqrt_x = X;
} else {
sqrt_x = matrix(0, rows=N, cols=N);
tmp = 0
for (i in 1:N) {
#workaround needed to access variable for it to be initialized
tmp = tmp + i
value = X[i, i];
sqrt_x[i, i] = sqrt(value);
}
}
}

isDiagonal = function (
Matrix[Double] X
)return(
boolean diagonal
){
N = nrow(X);
D = ncol(X);
noCells = N * D;

diag = diag(diag(X));
compare = X == diag;
sameCells = sum(compare);

#all cells should be the same to be diagonal
diagonal = noCells == sameCells;
}
2 changes: 2 additions & 0 deletions src/main/java/org/apache/sysds/common/Builtins.java
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,8 @@ public enum Builtins {
STEPLM("steplm",true, ReturnType.MULTI_RETURN),
STFT("stft", false, ReturnType.MULTI_RETURN),
SQRT("sqrt", false),
SQRT_MATRIX("matrixSqrt", true),
SQRT_MATRIX_JAVA("sqrt_matrix_java", false, ReturnType.SINGLE_RETURN),
SUM("sum", false),
SVD("svd", false, ReturnType.MULTI_RETURN),
TABLE("table", "ctable", false),
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/org/apache/sysds/common/Types.java
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,7 @@ public enum OpOp1 {
CUMSUMPROD, DETECTSCHEMA, COLNAMES, EIGEN, EXISTS, EXP, FLOOR, INVERSE,
IQM, ISNA, ISNAN, ISINF, LENGTH, LINEAGE, LOG, NCOL, NOT, NROW,
MEDIAN, PREFETCH, PRINT, ROUND, SIN, SINH, SIGN, SOFTMAX, SQRT, STOP, _EVICT,
SVD, TAN, TANH, TYPEOF, TRIGREMOTE,
SVD, TAN, TANH, TYPEOF, TRIGREMOTE, SQRT_MATRIX_JAVA,
//fused ML-specific operators for performance
SPROP, //sample proportion: P * (1 - P)
SIGMOID, //sigmoid function: 1 / (1 + exp(-X))
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/org/apache/sysds/hops/UnaryOp.java
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ && getInput().get(0).getParent().size()==1 ) //unary is only parent

//ensure cp exec type for single-node operations
if( _op == OpOp1.PRINT || _op == OpOp1.ASSERT || _op == OpOp1.STOP || _op == OpOp1.TYPEOF
|| _op == OpOp1.INVERSE || _op == OpOp1.EIGEN || _op == OpOp1.CHOLESKY || _op == OpOp1.SVD
|| _op == OpOp1.INVERSE || _op == OpOp1.EIGEN || _op == OpOp1.CHOLESKY || _op == OpOp1.SVD || _op == OpOp1.SQRT_MATRIX_JAVA
|| getInput().get(0).getDataType() == DataType.LIST || isMetadataOperation() )
{
_etype = ExecType.CP;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1759,6 +1759,19 @@ && isConstant(in[2])
output.setDimensions(in.getDim1(), in.getDim2());
output.setBlocksize(in.getBlocksize());
break;

case SQRT_MATRIX_JAVA:

checkNumParameters(1);
checkMatrixParam(getFirstExpr());
output.setDataType(DataType.MATRIX);
output.setValueType(ValueType.FP64);
Identifier sqrt = getFirstExpr().getOutput();
if(sqrt.dimsKnown() && sqrt.getDim1() != sqrt.getDim2())
raiseValidateError("Input to sqrtMatrix() must be square matrix -- given: a " + sqrt.getDim1() + "x" + sqrt.getDim2() + " matrix.", conditional);
output.setDimensions( sqrt.getDim1(), sqrt.getDim2());
output.setBlocksize( sqrt.getBlocksize());
break;

case CHOLESKY:
{
Expand Down
1 change: 1 addition & 0 deletions src/main/java/org/apache/sysds/parser/DMLTranslator.java
Original file line number Diff line number Diff line change
Expand Up @@ -2749,6 +2749,7 @@ else if ( in.length == 2 )
break;

case INVERSE:
case SQRT_MATRIX_JAVA:
case CHOLESKY:
case TYPEOF:
case DETECTSCHEMA:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ public class CPInstructionParser extends InstructionParser {
String2CPInstructionType.put( "ucummax", CPType.Unary);
String2CPInstructionType.put( "stop" , CPType.Unary);
String2CPInstructionType.put( "inverse", CPType.Unary);
String2CPInstructionType.put( "sqrt_matrix_java", CPType.Unary);
String2CPInstructionType.put( "cholesky",CPType.Unary);
String2CPInstructionType.put( "sprop", CPType.Unary);
String2CPInstructionType.put( "sigmoid", CPType.Unary);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ private LibCommonsMath() {
}

public static boolean isSupportedUnaryOperation( String opcode ) {
return ( opcode.equals("inverse") || opcode.equals("cholesky") );
return ( opcode.equals("inverse") || opcode.equals("cholesky") || opcode.equals("sqrt_matrix_java") );
}

public static boolean isSupportedMultiReturnOperation( String opcode ) {
Expand Down Expand Up @@ -111,6 +111,8 @@ public static MatrixBlock unaryOperations(MatrixBlock inj, String opcode) {
return computeMatrixInverse(matrixInput);
else if (opcode.equals("cholesky"))
return computeCholesky(matrixInput);
else if (opcode.equals("sqrt_matrix_java"))
return computeSqrt(inj);
return null;
}

Expand Down Expand Up @@ -512,7 +514,19 @@ private static MatrixBlock[] computeSvd(MatrixBlock in) {

return new MatrixBlock[] { U, Sigma, V };
}


/**
* Computes the square root of a matrix Calls Apache Commons Math EigenDecomposition.
*
* @param in Input matrix
* @return matrix block
*/
private static MatrixBlock computeSqrt(MatrixBlock in) {
Array2DRowRealMatrix matrixInput = DataConverter.convertToArray2DRowRealMatrix(in);
EigenDecomposition ed = new EigenDecomposition(matrixInput);
return DataConverter.convertToMatrixBlock(ed.getSquareRoot());
}

/**
* Function to compute matrix inverse via matrix decomposition.
*
Expand Down
Loading
Loading