diff --git a/src/main/java/org/apache/sysds/conf/DMLConfig.java b/src/main/java/org/apache/sysds/conf/DMLConfig.java index dd4d3b2457f..fd34fa4439c 100644 --- a/src/main/java/org/apache/sysds/conf/DMLConfig.java +++ b/src/main/java/org/apache/sysds/conf/DMLConfig.java @@ -201,7 +201,7 @@ public class DMLConfig _defaultVals.put(FLOATING_POINT_PRECISION, "double" ); _defaultVals.put(USE_SSL_FEDERATED_COMMUNICATION, "false"); _defaultVals.put(DEFAULT_FEDERATED_INITIALIZATION_TIMEOUT, "10"); - _defaultVals.put(FEDERATED_TIMEOUT, "-1"); + _defaultVals.put(FEDERATED_TIMEOUT, "86400"); // default 1 day compute timeout. _defaultVals.put(FEDERATED_PLANNER, FederatedPlanner.RUNTIME.name()); _defaultVals.put(FEDERATED_PAR_CONN, "-1"); // vcores _defaultVals.put(FEDERATED_PAR_INST, "-1"); // vcores diff --git a/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeGreedy.java b/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeGreedy.java index cc6310d70bc..51cc34fa9e4 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeGreedy.java +++ b/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeGreedy.java @@ -65,7 +65,7 @@ private List coCodeBruteForce(List workSet = new ArrayList<>(inputColumns.size()); k = k <= 0 ? InfrastructureAnalyzer.getLocalParallelism() : k; - final ExecutorService pool = CommonThreadPool.get(k); + final ExecutorService pool = k > 1 ? 
CommonThreadPool.get(k) : null; try { for(int i = 0; i < inputColumns.size(); i++) { CompressedSizeInfoColGroup g = inputColumns.get(i); @@ -183,7 +183,8 @@ else if((newCostIfJoined < changeInCost || return ret; } finally { - pool.shutdown(); + if(pool != null) + pool.shutdown(); } } @@ -195,8 +196,12 @@ protected void parallelFirstCombine(List workSet, ExecutorService po for(int j = i + 1; j < size; j++) tasks.add(new CombineTask(workSet.get(i), workSet.get(j))); - for(Future t : pool.invokeAll(tasks)) - t.get(); + if(pool != null) + for(Future t : pool.invokeAll(tasks)) + t.get(); + else + for(CombineTask t: tasks) + t.call(); } catch(Exception e) { throw new DMLCompressionException("Failed parallelize first level all join all", e); diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java index 85bb00951f8..8cfe4639e45 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java @@ -418,8 +418,38 @@ private static void divideByNumberOfCellsForMeanAll(CompressedMatrixBlock m1, Ma private static void decompressingAggregate(CompressedMatrixBlock m1, MatrixBlock ret, AggregateUnaryOperator op, MatrixIndexes indexesIn, boolean inCP) throws Exception { - List> rtasks = generateUnaryAggregateOverlappingFutures(m1, ret, op); - reduceFutures(rtasks, ret, op, true); + if(op.getNumThreads() > 1){ + + List> rtasks = generateUnaryAggregateOverlappingFutures(m1, ret, op); + reduceFutures(rtasks, ret, op, true); + } + else{ + final int nCol = m1.getNumColumns(); + final int nRow = m1.getNumRows(); + final List groups = m1.getColGroups(); + final boolean shouldFilter = CLALibUtils.shouldPreFilter(groups); + + final UAOverlappingTask t; + if(shouldFilter) { + final double[] constV = new double[nCol]; + final List filteredGroups = CLALibUtils.filterGroups(groups, constV); + final 
AColGroup cRet = ColGroupConst.create(constV); + filteredGroups.add(cRet); + t = new UAOverlappingTask(filteredGroups, ret, 0, nRow, op, nCol); + } + else { + t = new UAOverlappingTask(groups, ret, 0, nRow, op, nCol); + } + if(op.indexFn instanceof ReduceAll) + ret.set(0,0,t.call().get(0,0)); + else if(op.indexFn instanceof ReduceRow) { + final boolean isPlus = op.aggOp.increOp.fn instanceof Mean || op.aggOp.increOp.fn instanceof KahanFunction; + final BinaryOperator bop = isPlus ? new BinaryOperator(Plus.getPlusFnObject()) : op. aggOp.increOp; + LibMatrixBincell.bincellOpInPlace(ret, t.call(), bop); + } + else // reduce cols just get the tasks done. + t.call(); + } } private static void reduceFutures(List> futures, MatrixBlock ret, AggregateUnaryOperator op, diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScalar.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScalar.java index 5588a538aa6..ce27781181c 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScalar.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScalar.java @@ -108,16 +108,35 @@ public static MatrixBlock scalarOperations(ScalarOperator sop, CompressedMatrixB } private static MatrixBlock fusedScalarAndDecompress(CompressedMatrixBlock in, ScalarOperator sop) { + if(sop.getNumThreads() <= 1) + return singleThreadFusedScalarAndDecompress(in, sop); + return parallelFusedScalarAndDecompress(in, sop) ; + } + + private static MatrixBlock singleThreadFusedScalarAndDecompress(CompressedMatrixBlock in, ScalarOperator sop){ + final int nRow = in.getNumRows(); + final int nCol = in.getNumColumns(); + final MatrixBlock out = new MatrixBlock(nRow, nCol, false); + out.allocateDenseBlock(); + final DenseBlock db = out.getDenseBlock(); + final List groups = in.getColGroups(); + long nnz = fusedDecompressAndScalar(groups, nCol, 0, nRow, db, sop); + out.setNonZeros(nnz); + out.examSparsity(true); + return out; + } + + private static MatrixBlock 
parallelFusedScalarAndDecompress(CompressedMatrixBlock in, ScalarOperator sop) { int k = sop.getNumThreads(); ExecutorService pool = CommonThreadPool.get(k); try { - final int nRow = in.getNumRows(); + final int nRow = in.getNumRows(); final int nCol = in.getNumColumns(); final MatrixBlock out = new MatrixBlock(nRow, nCol, false); final List groups = in.getColGroups(); out.allocateDenseBlock(); final DenseBlock db = out.getDenseBlock(); - final int blkz = Math.max((int)(Math.ceil((double)nRow / k)), 256); + final int blkz = Math.max((int) (Math.ceil((double) nRow / k)), 256); final List> tasks = new ArrayList<>(); for(int i = 0; i < nRow; i += blkz) { final int start = i; @@ -138,9 +157,6 @@ private static MatrixBlock fusedScalarAndDecompress(CompressedMatrixBlock in, Sc finally { pool.shutdown(); } - - // MatrixBlock m1d = m1.decompress(sop.getNumThreads()); - // return m1d.scalarOperations(sop, result); } private static long fusedDecompressAndScalar(final List groups, int nCol, int start, int end, diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java index e95375cc75a..339aabe9768 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java @@ -503,7 +503,8 @@ public static List getWorkerDataObjects() { return new ArrayList<>(workerDataObjects.values()); } - public static void addEvent(EventModel event) { + public synchronized static void addEvent(EventModel event) { + // synchronized, because multiple requests can be handled concurrently workerEvents.add(event); } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java index ceaf61c225c..ce21c79825b 
100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java @@ -651,8 +651,7 @@ private FederatedResponse execUDF(FederatedRequest request, ExecutionContextMap // get function and input parameters try { FederatedUDF udf = (FederatedUDF) request.getParam(0); - if(LOG.isDebugEnabled()) - LOG.debug(udf); + LOG.debug(udf); eventStage.operation = udf.getClass().getSimpleName(); diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java index 2574c4f1759..a6ae6d55424 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java @@ -26,7 +26,9 @@ import java.util.Map.Entry; import java.util.concurrent.Callable; import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; import java.util.function.BiFunction; import java.util.stream.IntStream; import java.util.stream.Stream; @@ -34,6 +36,7 @@ import org.apache.commons.lang3.tuple.Pair; import org.apache.sysds.common.Types.DataType; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.conf.ConfigurationManager; import org.apache.sysds.hops.fedplanner.FTypes.AlignType; import org.apache.sysds.hops.fedplanner.FTypes.FType; import org.apache.sysds.lops.RightIndex; @@ -637,11 +640,29 @@ public long getMaxIndexInRange(int dim) { * @param forEachFunction function to execute for each pair */ public void forEachParallel(BiFunction forEachFunction) { - ExecutorService pool = CommonThreadPool.get(_fedMap.size()); + ExecutorService pool = Executors.newFixedThreadPool(_fedMap.size()); ArrayList mappingTasks = new ArrayList<>(); for(Pair 
fedMap : _fedMap) mappingTasks.add(new MappingTask(fedMap.getKey(), fedMap.getValue(), forEachFunction, _ID)); - CommonThreadPool.invokeAndShutdown(pool, mappingTasks); + + try { + // invokeAll with a timeout returns only finished futures; tasks that ran out of time come back cancelled + for(Future t : pool.invokeAll(mappingTasks, ConfigurationManager.getFederatedTimeout(), TimeUnit.SECONDS)) { + if(t.isCancelled()) + throw new RuntimeException("Timeout"); + t.get(); // rethrows the failure of a mapping task instead of dropping it + } + } + catch(InterruptedException e) { + Thread.currentThread().interrupt(); // keep the interrupt visible to callers + throw new RuntimeException(e); + } + catch(java.util.concurrent.ExecutionException e) { + throw new RuntimeException("Failed", e); + } + finally { + pool.shutdown(); + } } /** @@ -655,15 +676,31 @@ public void forEachParallel(BiFunction forE * @return the new FederationMap */ public FederationMap mapParallel(long newVarID, BiFunction mappingFunction) { - ExecutorService pool = CommonThreadPool.get(_fedMap.size()); - + ExecutorService pool = Executors.newFixedThreadPool(_fedMap.size()); FederationMap fedMapCopy = copyWithNewID(_ID); ArrayList mappingTasks = new ArrayList<>(); for(Pair fedMap : fedMapCopy._fedMap) mappingTasks.add(new MappingTask(fedMap.getKey(), fedMap.getValue(), mappingFunction, newVarID)); - CommonThreadPool.invokeAndShutdown(pool, mappingTasks); - fedMapCopy._ID = newVarID; - return fedMapCopy; + try { + // invokeAll with a timeout returns only finished futures; tasks that ran out of time come back cancelled + for(Future t : pool.invokeAll(mappingTasks, ConfigurationManager.getFederatedTimeout(), TimeUnit.SECONDS)) { + if(t.isCancelled()) + throw new RuntimeException("Timeout"); + t.get(); // rethrows the failure of a mapping task instead of dropping it + } + fedMapCopy._ID = newVarID; + return fedMapCopy; + } + catch(InterruptedException e) { + Thread.currentThread().interrupt(); // keep the interrupt visible to callers + throw new RuntimeException(e); + } + catch(java.util.concurrent.ExecutionException e) { + throw new RuntimeException("Failed", e); + } + finally { + pool.shutdown(); // pool is created per call; release its threads on every exit path + } } public FederationMap filter(IndexRange ixrange) { diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java index 63e8fe1672f..be3bf9de11c 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java +++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java @@ -51,6 +51,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -61,6 +62,7 @@ import org.apache.spark.util.LongAccumulator; import org.apache.sysds.api.DMLScript; import org.apache.sysds.common.Types.ExecType; +import org.apache.sysds.conf.ConfigurationManager; import org.apache.sysds.hops.recompile.Recompiler; import org.apache.sysds.parser.Statement.FederatedPSScheme; import org.apache.sysds.parser.Statement.PSFrequency; @@ -241,13 +243,16 @@ model, aggServiceEC, getValFunction(), getNumBatchesPerEpoch(runtimeBalancing, r try { // Launch the worker threads and wait for completion - for (Future ret : es.invokeAll(threads)) - ret.get(); //error handling + for (Future ret : es.invokeAll(threads, ConfigurationManager.getFederatedTimeout(), TimeUnit.SECONDS)){ + if(ret.isCancelled()) // timed-out workers are returned as cancelled futures; isDone() is always true here + throw new RuntimeException("Failed federated execution"); + ret.get(); //error handling: rethrows a worker failure instead of silently ignoring it + } // Fetch the final model from ps ec.setVariable(output.getName(), ps.getResult()); if (DMLScript.STATISTICS) ParamServStatistics.accExecutionTime((long) ParamServStatistics.getExecutionTimer().stop()); - } catch (InterruptedException | ExecutionException e) { + } catch (Exception e) { throw new DMLRuntimeException("ParamservBuiltinCPInstruction: unknown error: ", e); } finally { es.shutdownNow(); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java index 69e0361ee7e..4d7a63cf90b 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java +++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java @@ -35,6 +35,7 @@ import org.apache.sysds.common.Types; import org.apache.sysds.common.Types.DataType; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.hops.fedplanner.FTypes; import org.apache.sysds.hops.fedplanner.FTypes.FType; import org.apache.sysds.lops.PickByCount; @@ -47,10 +48,10 @@ import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType; import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse; import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse.ResponseType; -import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF; import org.apache.sysds.runtime.controlprogram.federated.FederationMap; import org.apache.sysds.runtime.controlprogram.federated.FederationUtils; +import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.instructions.InstructionUtils; import org.apache.sysds.runtime.instructions.cp.CPOperand; import org.apache.sysds.runtime.instructions.cp.Data; @@ -175,6 +176,7 @@ public void processInstruction(ExecutionContext ec) { try { FederatedResponse response = responseFuture.get(); MultiColumnEncoder encoder = (MultiColumnEncoder) response.getData()[0]; + // merge this encoder into a composite encoder synchronized(finalGlobalEncoder) { finalGlobalEncoder.mergeAt(encoder, columnOffset, (int) (range.getBeginDims()[0] + 1)); @@ -378,24 +380,30 @@ public ExecuteFrameEncoder(long input, long output, MultiColumnEncoder encoder) @Override public FederatedResponse execute(ExecutionContext ec, Data... 
data) { - FrameBlock fb = ((FrameObject) data[0]).acquireReadAndRelease(); - - // offset is applied on the Worker to shift the local encoders to their respective column - _encoder.applyColumnOffset(); - // apply transformation - //MatrixBlock mbout = _encoder.apply(fb, OptimizerUtils.getTransformNumThreads()); - // FIXME: Enabling multithreading intermittently hangs - MatrixBlock mbout = _encoder.apply(fb, 1); - - // create output matrix object - MatrixObject mo = ExecutionContext.createMatrixObject(mbout); - - // add it to the list of variables - ec.setVariable(String.valueOf(_outputID), mo); + try{ - // return id handle - return new FederatedResponse( - ResponseType.SUCCESS_EMPTY, mbout.getNonZeros()); + FrameBlock fb = ((FrameObject) data[0]).acquireReadAndRelease(); + + // offset is applied on the Worker to shift the local encoders to their respective column + _encoder.applyColumnOffset(); + // apply transformation + MatrixBlock mbout = _encoder.apply(fb, OptimizerUtils.getTransformNumThreads()); + // NOTE: this call used to be pinned to a single thread because multithreading intermittently hung; + // if the hang reappears, fall back to _encoder.apply(fb, 1) + + // create output matrix object + MatrixObject mo = ExecutionContext.createMatrixObject(mbout); + + // add it to the list of variables + ec.setVariable(String.valueOf(_outputID), mo); + + // return id handle + return new FederatedResponse( + ResponseType.SUCCESS_EMPTY, mbout.getNonZeros()); + } + catch(Exception e){ + return new FederatedResponse(ResponseType.ERROR, e); // ship the cause with the error response instead of dropping it + } } @Override diff --git a/src/test/config/SystemDS-MultiTenant-config.xml b/src/test/config/SystemDS-MultiTenant-config.xml index ad2bcf0ee5f..321fcc0b282 100644 --- a/src/test/config/SystemDS-MultiTenant-config.xml +++ b/src/test/config/SystemDS-MultiTenant-config.xml @@ -21,6 +21,6 @@ 30 - 128 + 16 true diff --git a/src/test/config/SystemDS-config.xml b/src/test/config/SystemDS-config.xml index a6f5ba525f7..a899f5c71c6 100644 --- a/src/test/config/SystemDS-config.xml +++ 
b/src/test/config/SystemDS-config.xml @@ -23,5 +23,5 @@ 2 - 128 + 16 diff --git a/src/test/java/org/apache/sysds/test/functions/codegen/RowVectorComparisonTest.java b/src/test/java/org/apache/sysds/test/functions/codegen/RowVectorComparisonTest.java index 3f287bdc07b..b5d5a77474d 100644 --- a/src/test/java/org/apache/sysds/test/functions/codegen/RowVectorComparisonTest.java +++ b/src/test/java/org/apache/sysds/test/functions/codegen/RowVectorComparisonTest.java @@ -128,6 +128,7 @@ private void testCodegenIntegration( String testname, boolean rewrites, ExecType { boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; ExecMode platformOld = setExecMode(instType); + setOutputBuffering(true); try { diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTokenizeTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTokenizeTest.java index a79a24a3b5b..f8acc4623a7 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTokenizeTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTokenizeTest.java @@ -79,6 +79,7 @@ public void testTokenizeFullDenseFrameCP() { private void runAggregateOperationTest(ExecMode execMode) { setExecMode(execMode); + setOutputBuffering(true); String TEST_NAME = TEST_NAME1;