From ce231510408ea00573d3632e54fd9b83a087d336 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 13 Jan 2025 09:58:05 +0100 Subject: [PATCH 1/4] [MINOR] Cocode intelligent parallel allocation --- .../sysds/runtime/compress/cocode/CoCodeGreedy.java | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeGreedy.java b/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeGreedy.java index cc6310d70bc..51cc34fa9e4 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeGreedy.java +++ b/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeGreedy.java @@ -65,7 +65,7 @@ private List coCodeBruteForce(List workSet = new ArrayList<>(inputColumns.size()); k = k <= 0 ? InfrastructureAnalyzer.getLocalParallelism() : k; - final ExecutorService pool = CommonThreadPool.get(k); + final ExecutorService pool = k > 1 ? CommonThreadPool.get(k) : null; try { for(int i = 0; i < inputColumns.size(); i++) { CompressedSizeInfoColGroup g = inputColumns.get(i); @@ -183,7 +183,8 @@ else if((newCostIfJoined < changeInCost || return ret; } finally { - pool.shutdown(); + if(pool != null) + pool.shutdown(); } } @@ -195,8 +196,12 @@ protected void parallelFirstCombine(List workSet, ExecutorService po for(int j = i + 1; j < size; j++) tasks.add(new CombineTask(workSet.get(i), workSet.get(j))); - for(Future t : pool.invokeAll(tasks)) - t.get(); + if(pool != null) + for(Future t : pool.invokeAll(tasks)) + t.get(); + else + for(CombineTask t: tasks) + t.call(); } catch(Exception e) { throw new DMLCompressionException("Failed parallelize first level all join all", e); From 75874c5c83a4dfe01c238f2c0c3a883a2fc86878 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 13 Jan 2025 09:59:07 +0100 Subject: [PATCH 2/4] [MINOR] Fix singlethread CLA decompressingAggregate --- .../runtime/compress/lib/CLALibCompAgg.java | 34 +++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java index 85bb00951f8..8cfe4639e45 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java @@ -418,8 +418,38 @@ private static void divideByNumberOfCellsForMeanAll(CompressedMatrixBlock m1, Ma private static void decompressingAggregate(CompressedMatrixBlock m1, MatrixBlock ret, AggregateUnaryOperator op, MatrixIndexes indexesIn, boolean inCP) throws Exception { - List> rtasks = generateUnaryAggregateOverlappingFutures(m1, ret, op); - reduceFutures(rtasks, ret, op, true); + if(op.getNumThreads() > 1){ + + List> rtasks = generateUnaryAggregateOverlappingFutures(m1, ret, op); + reduceFutures(rtasks, ret, op, true); + } + else{ + final int nCol = m1.getNumColumns(); + final int nRow = m1.getNumRows(); + final List groups = m1.getColGroups(); + final boolean shouldFilter = CLALibUtils.shouldPreFilter(groups); + + final UAOverlappingTask t; + if(shouldFilter) { + final double[] constV = new double[nCol]; + final List filteredGroups = CLALibUtils.filterGroups(groups, constV); + final AColGroup cRet = ColGroupConst.create(constV); + filteredGroups.add(cRet); + t = new UAOverlappingTask(filteredGroups, ret, 0, nRow, op, nCol); + } + else { + t = new UAOverlappingTask(groups, ret, 0, nRow, op, nCol); + } + if(op.indexFn instanceof ReduceAll) + ret.set(0,0,t.call().get(0,0)); + else if(op.indexFn instanceof ReduceRow) { + final boolean isPlus = op.aggOp.increOp.fn instanceof Mean || op.aggOp.increOp.fn instanceof KahanFunction; + final BinaryOperator bop = isPlus ? new BinaryOperator(Plus.getPlusFnObject()) : op. aggOp.increOp; + LibMatrixBincell.bincellOpInPlace(ret, t.call(), bop); + } + else // reduce cols just get the tasks done. + t.call(); + } } private static void reduceFutures(List> futures, MatrixBlock ret, AggregateUnaryOperator op, From cc561444d693d5d77744af69b6db6ff09ae2e6c7 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 13 Jan 2025 10:01:03 +0100 Subject: [PATCH 3/4] [MINOR] Add singlethread fused scalar and decompress --- .../runtime/compress/lib/CLALibScalar.java | 26 +++++++++++++++---- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScalar.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScalar.java index 5588a538aa6..ce27781181c 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScalar.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibScalar.java @@ -108,16 +108,35 @@ public static MatrixBlock scalarOperations(ScalarOperator sop, CompressedMatrixB } private static MatrixBlock fusedScalarAndDecompress(CompressedMatrixBlock in, ScalarOperator sop) { + if(sop.getNumThreads() <= 1) + return singleThreadFusedScalarAndDecompress(in, sop); + return parallelFusedScalarAndDecompress(in, sop) ; + } + + private static MatrixBlock singleThreadFusedScalarAndDecompress(CompressedMatrixBlock in, ScalarOperator sop){ + final int nRow = in.getNumRows(); + final int nCol = in.getNumColumns(); + final MatrixBlock out = new MatrixBlock(nRow, nCol, false); + out.allocateDenseBlock(); + final DenseBlock db = out.getDenseBlock(); + final List groups = in.getColGroups(); + long nnz = fusedDecompressAndScalar(groups, nCol, 0, nRow, db, sop); + out.setNonZeros(nnz); + out.examSparsity(true); + return out; + } + + private static MatrixBlock parallelFusedScalarAndDecompress(CompressedMatrixBlock in, ScalarOperator sop) { int k = sop.getNumThreads(); ExecutorService pool = CommonThreadPool.get(k); try { - final int nRow = in.getNumRows(); + final int nRow = in.getNumRows(); final int nCol = in.getNumColumns(); final MatrixBlock out = new MatrixBlock(nRow, nCol, false); final List groups = in.getColGroups(); out.allocateDenseBlock(); final DenseBlock db = out.getDenseBlock(); - final int blkz = Math.max((int)(Math.ceil((double)nRow / k)), 256); + final int blkz = Math.max((int) (Math.ceil((double) nRow / k)), 256); final List> tasks = new ArrayList<>(); for(int i = 0; i < nRow; i += blkz) { final int start = i; @@ -138,9 +157,6 @@ private static MatrixBlock fusedScalarAndDecompress(CompressedMatrixBlock in, Sc finally { pool.shutdown(); } - - // MatrixBlock m1d = m1.decompress(sop.getNumThreads()); - // return m1d.scalarOperations(sop, result); } private static long fusedDecompressAndScalar(final List groups, int nCol, int start, int end, From 3acbab4105d801f4697315447541ce8031743fb0 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 13 Jan 2025 10:07:16 +0100 Subject: [PATCH 4/4] [MINOR] Federated Timeout This commit reduce the timeout for federated tests, and enforce the timeout on federated requests. Previously we had some test cases that would infinitely run, and therefore we would not be able to decipher the log messages (because nothing would be written). This commit change it by enforcing a strict 16 seconds execution time for a single federated requests and a 1 day timeout for a default federated requests. Previously some operations did use the federated timeout. However, it was not enforced in critical places. --- .../java/org/apache/sysds/conf/DMLConfig.java | 2 +- .../federated/FederatedStatistics.java | 3 +- .../federated/FederatedWorkerHandler.java | 3 +- .../federated/FederationMap.java | 41 ++++++++++++++--- .../cp/ParamservBuiltinCPInstruction.java | 11 +++-- ...urnParameterizedBuiltinFEDInstruction.java | 44 +++++++++++-------- .../config/SystemDS-MultiTenant-config.xml | 2 +- src/test/config/SystemDS-config.xml | 2 +- .../codegen/RowVectorComparisonTest.java | 1 + .../part3/FederatedTokenizeTest.java | 1 + 10 files changed, 76 insertions(+), 34 deletions(-) diff --git a/src/main/java/org/apache/sysds/conf/DMLConfig.java b/src/main/java/org/apache/sysds/conf/DMLConfig.java index dd4d3b2457f..fd34fa4439c 100644 --- a/src/main/java/org/apache/sysds/conf/DMLConfig.java +++ b/src/main/java/org/apache/sysds/conf/DMLConfig.java @@ -201,7 +201,7 @@ public class DMLConfig _defaultVals.put(FLOATING_POINT_PRECISION, "double" ); _defaultVals.put(USE_SSL_FEDERATED_COMMUNICATION, "false"); _defaultVals.put(DEFAULT_FEDERATED_INITIALIZATION_TIMEOUT, "10"); - _defaultVals.put(FEDERATED_TIMEOUT, "-1"); + _defaultVals.put(FEDERATED_TIMEOUT, "86400"); // default 1 day compute timeout. _defaultVals.put(FEDERATED_PLANNER, FederatedPlanner.RUNTIME.name()); _defaultVals.put(FEDERATED_PAR_CONN, "-1"); // vcores _defaultVals.put(FEDERATED_PAR_INST, "-1"); // vcores diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java index e95375cc75a..339aabe9768 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedStatistics.java @@ -503,7 +503,8 @@ public static List getWorkerDataObjects() { return new ArrayList<>(workerDataObjects.values()); } - public static void addEvent(EventModel event) { + public synchronized static void addEvent(EventModel event) { + // synchronized, because multiple requests can be handled concurrently workerEvents.add(event); } diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java index ceaf61c225c..ce21c79825b 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java @@ -651,8 +651,7 @@ private FederatedResponse execUDF(FederatedRequest request, ExecutionContextMap // get function and input parameters try { FederatedUDF udf = (FederatedUDF) request.getParam(0); - if(LOG.isDebugEnabled()) - LOG.debug(udf); + LOG.debug(udf); eventStage.operation = udf.getClass().getSimpleName(); diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java index 2574c4f1759..a6ae6d55424 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java @@ -26,7 +26,9 @@ import java.util.Map.Entry; import java.util.concurrent.Callable; import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; import java.util.function.BiFunction; import java.util.stream.IntStream; import java.util.stream.Stream; @@ -34,6 +36,7 @@ import org.apache.commons.lang3.tuple.Pair; import org.apache.sysds.common.Types.DataType; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.conf.ConfigurationManager; import org.apache.sysds.hops.fedplanner.FTypes.AlignType; import org.apache.sysds.hops.fedplanner.FTypes.FType; import org.apache.sysds.lops.RightIndex; @@ -637,11 +640,25 @@ public long getMaxIndexInRange(int dim) { * @param forEachFunction function to execute for each pair */ public void forEachParallel(BiFunction forEachFunction) { - ExecutorService pool = CommonThreadPool.get(_fedMap.size()); + ExecutorService pool = Executors.newFixedThreadPool(_fedMap.size()); ArrayList mappingTasks = new ArrayList<>(); for(Pair fedMap : _fedMap) mappingTasks.add(new MappingTask(fedMap.getKey(), fedMap.getValue(), forEachFunction, _ID)); - CommonThreadPool.invokeAndShutdown(pool, mappingTasks); + + try { + for(Future t:pool.invokeAll(mappingTasks, ConfigurationManager.getFederatedTimeout(), TimeUnit.SECONDS)){ + if(!t.isDone()) + throw new RuntimeException("Timeout"); + else if(t.isCancelled()) + throw new RuntimeException("Failed"); + } + } + catch(InterruptedException e) { + throw new RuntimeException(e); + } + finally{ + pool.shutdown(); + } } /** @@ -655,15 +672,25 @@ public void forEachParallel(BiFunction forE * @return the new FederationMap */ public FederationMap mapParallel(long newVarID, BiFunction mappingFunction) { - ExecutorService pool = CommonThreadPool.get(_fedMap.size()); - + ExecutorService pool = Executors.newFixedThreadPool(_fedMap.size()); FederationMap fedMapCopy = copyWithNewID(_ID); ArrayList mappingTasks = new ArrayList<>(); for(Pair fedMap : fedMapCopy._fedMap) mappingTasks.add(new MappingTask(fedMap.getKey(), fedMap.getValue(), mappingFunction, newVarID)); - CommonThreadPool.invokeAndShutdown(pool, mappingTasks); - fedMapCopy._ID = newVarID; - return fedMapCopy; + try{ + for(Future t : pool.invokeAll(mappingTasks, ConfigurationManager.getFederatedTimeout(), TimeUnit.SECONDS)){ + if(!t.isDone()) + throw new RuntimeException("Timeout"); + else if(t.isCancelled()){ + throw new RuntimeException("Failed"); + } + } + fedMapCopy._ID = newVarID; + return fedMapCopy; + } + catch(Exception e){ + throw new RuntimeException(e); + } } public FederationMap filter(IndexRange ixrange) { diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java index 63e8fe1672f..be3bf9de11c 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParamservBuiltinCPInstruction.java @@ -51,6 +51,7 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -61,6 +62,7 @@ import org.apache.spark.util.LongAccumulator; import org.apache.sysds.api.DMLScript; import org.apache.sysds.common.Types.ExecType; +import org.apache.sysds.conf.ConfigurationManager; import org.apache.sysds.hops.recompile.Recompiler; import org.apache.sysds.parser.Statement.FederatedPSScheme; import org.apache.sysds.parser.Statement.PSFrequency; @@ -241,13 +243,16 @@ model, aggServiceEC, getValFunction(), getNumBatchesPerEpoch(runtimeBalancing, r try { // Launch the worker threads and wait for completion - for (Future ret : es.invokeAll(threads)) - ret.get(); //error handling + for (Future ret : es.invokeAll(threads, ConfigurationManager.getFederatedTimeout(), TimeUnit.SECONDS)){ + if(!ret.isDone()) + throw new RuntimeException("Failed federated execution"); + // ret.get(); //error handling + } // Fetch the final model from ps ec.setVariable(output.getName(), ps.getResult()); if (DMLScript.STATISTICS) ParamServStatistics.accExecutionTime((long) ParamServStatistics.getExecutionTimer().stop()); - } catch (InterruptedException | ExecutionException e) { + } catch (Exception e) { throw new DMLRuntimeException("ParamservBuiltinCPInstruction: unknown error: ", e); } finally { es.shutdownNow(); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java index 69e0361ee7e..4d7a63cf90b 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java @@ -35,6 +35,7 @@ import org.apache.sysds.common.Types; import org.apache.sysds.common.Types.DataType; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.hops.fedplanner.FTypes; import org.apache.sysds.hops.fedplanner.FTypes.FType; import org.apache.sysds.lops.PickByCount; @@ -47,10 +48,10 @@ import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType; import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse; import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse.ResponseType; -import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF; import org.apache.sysds.runtime.controlprogram.federated.FederationMap; import org.apache.sysds.runtime.controlprogram.federated.FederationUtils; +import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.instructions.InstructionUtils; import org.apache.sysds.runtime.instructions.cp.CPOperand; import org.apache.sysds.runtime.instructions.cp.Data; @@ -175,6 +176,7 @@ public void processInstruction(ExecutionContext ec) { try { FederatedResponse response = responseFuture.get(); MultiColumnEncoder encoder = (MultiColumnEncoder) response.getData()[0]; + // merge this encoder into a composite encoder synchronized(finalGlobalEncoder) { finalGlobalEncoder.mergeAt(encoder, columnOffset, (int) (range.getBeginDims()[0] + 1)); @@ -378,24 +380,30 @@ public ExecuteFrameEncoder(long input, long output, MultiColumnEncoder encoder) @Override public FederatedResponse execute(ExecutionContext ec, Data... data) { - FrameBlock fb = ((FrameObject) data[0]).acquireReadAndRelease(); - - // offset is applied on the Worker to shift the local encoders to their respective column - _encoder.applyColumnOffset(); - // apply transformation - //MatrixBlock mbout = _encoder.apply(fb, OptimizerUtils.getTransformNumThreads()); - // FIXME: Enabling multithreading intermittently hangs - MatrixBlock mbout = _encoder.apply(fb, 1); - - // create output matrix object - MatrixObject mo = ExecutionContext.createMatrixObject(mbout); - - // add it to the list of variables - ec.setVariable(String.valueOf(_outputID), mo); + try{ - // return id handle - return new FederatedResponse( - ResponseType.SUCCESS_EMPTY, mbout.getNonZeros()); + FrameBlock fb = ((FrameObject) data[0]).acquireReadAndRelease(); + + // offset is applied on the Worker to shift the local encoders to their respective column + _encoder.applyColumnOffset(); + // apply transformation + MatrixBlock mbout = _encoder.apply(fb, OptimizerUtils.getTransformNumThreads()); + // FIXME: Enabling multithreading intermittently hangs + // MatrixBlock mbout = _encoder.apply(fb, 1); + + // create output matrix object + MatrixObject mo = ExecutionContext.createMatrixObject(mbout); + + // add it to the list of variables + ec.setVariable(String.valueOf(_outputID), mo); + + // return id handle + return new FederatedResponse( + ResponseType.SUCCESS_EMPTY, mbout.getNonZeros()); + } + catch(Exception e){ + return new FederatedResponse(ResponseType.ERROR); + } } @Override diff --git a/src/test/config/SystemDS-MultiTenant-config.xml b/src/test/config/SystemDS-MultiTenant-config.xml index ad2bcf0ee5f..321fcc0b282 100644 --- a/src/test/config/SystemDS-MultiTenant-config.xml +++ b/src/test/config/SystemDS-MultiTenant-config.xml @@ -21,6 +21,6 @@ 30 - 128 + 16 true diff --git a/src/test/config/SystemDS-config.xml b/src/test/config/SystemDS-config.xml index a6f5ba525f7..a899f5c71c6 100644 --- a/src/test/config/SystemDS-config.xml +++ b/src/test/config/SystemDS-config.xml @@ -23,5 +23,5 @@ 2 - 128 + 16 diff --git a/src/test/java/org/apache/sysds/test/functions/codegen/RowVectorComparisonTest.java b/src/test/java/org/apache/sysds/test/functions/codegen/RowVectorComparisonTest.java index 3f287bdc07b..b5d5a77474d 100644 --- a/src/test/java/org/apache/sysds/test/functions/codegen/RowVectorComparisonTest.java +++ b/src/test/java/org/apache/sysds/test/functions/codegen/RowVectorComparisonTest.java @@ -128,6 +128,7 @@ private void testCodegenIntegration( String testname, boolean rewrites, ExecType { boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION; ExecMode platformOld = setExecMode(instType); + setOutputBuffering(true); try { diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTokenizeTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTokenizeTest.java index a79a24a3b5b..f8acc4623a7 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTokenizeTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/part3/FederatedTokenizeTest.java @@ -79,6 +79,7 @@ public void testTokenizeFullDenseFrameCP() { private void runAggregateOperationTest(ExecMode execMode) { setExecMode(execMode); + setOutputBuffering(true); String TEST_NAME = TEST_NAME1;