Skip to content

Commit fdfc718

Browse files
committed
Track more compressed-friendly ops in FederatedWorkloadAnalyzer
Extends the federated workload counter so that compression decisions account for additional instruction shapes beyond AggregateBinary. - Pass the right-hand column count to incOverlappingDecompressions so the cost model reflects the actual decompression size rather than counting a single column - Count MMChainCPInstruction as one LMM and one RMM contribution per invocation - Count AggregateUnaryCPInstruction: when reducing columns with a sum/mean operator, treat it as a dict-op (compression-friendly); otherwise count it as a decompression - Minor formatting cleanup in compressRun
1 parent 9a4e2a3 commit fdfc718

1 file changed

Lines changed: 48 additions & 5 deletions

File tree

src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkloadAnalyzer.java

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,18 @@
2727
import org.apache.sysds.runtime.compress.cost.InstructionTypeCounter;
2828
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
2929
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
30+
import org.apache.sysds.runtime.functionobjects.IndexFunction;
31+
import org.apache.sysds.runtime.functionobjects.KahanPlus;
32+
import org.apache.sysds.runtime.functionobjects.Mean;
33+
import org.apache.sysds.runtime.functionobjects.Plus;
34+
import org.apache.sysds.runtime.functionobjects.ReduceCol;
3035
import org.apache.sysds.runtime.instructions.Instruction;
3136
import org.apache.sysds.runtime.instructions.cp.AggregateBinaryCPInstruction;
37+
import org.apache.sysds.runtime.instructions.cp.AggregateUnaryCPInstruction;
3238
import org.apache.sysds.runtime.instructions.cp.ComputationCPInstruction;
39+
import org.apache.sysds.runtime.instructions.cp.MMChainCPInstruction;
40+
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
41+
import org.apache.sysds.runtime.matrix.operators.Operator;
3342

3443
public class FederatedWorkloadAnalyzer {
3544
protected static final Log LOG = LogFactory.getLog(FederatedWorkloadAnalyzer.class.getName());
@@ -55,7 +64,7 @@ public void incrementWorkload(ExecutionContext ec, long tid, Instruction ins) {
5564
}
5665

5766
public void compressRun(ExecutionContext ec, long tid) {
58-
if(counter >= compressRunFrequency ){
67+
if(counter >= compressRunFrequency) {
5968
counter = 0;
6069
get(tid).forEach((K, V) -> CompressedMatrixBlockFactory.compressAsync(ec, Long.toString(K), V));
6170
}
@@ -68,6 +77,7 @@ private void incrementWorkload(ExecutionContext ec, long tid, ComputationCPInstr
6877
public void incrementWorkload(ExecutionContext ec, ConcurrentHashMap<Long, InstructionTypeCounter> mm,
6978
ComputationCPInstruction cpIns) {
7079
// TODO: Count transitive closure via lineage
80+
// TODO: add more operations
7181
if(cpIns instanceof AggregateBinaryCPInstruction) {
7282
final String n1 = cpIns.input1.getName();
7383
MatrixObject d1 = (MatrixObject) ec.getCacheableData(n1);
@@ -81,15 +91,48 @@ public void incrementWorkload(ExecutionContext ec, ConcurrentHashMap<Long, Instr
8191
if(validSize(r1, c1)) {
8292
getOrMakeCounter(mm, Long.parseLong(n1)).incRMM(c2);
8393
// safety add overlapping decompress for RMM
84-
getOrMakeCounter(mm, Long.parseLong(n1)).incOverlappingDecompressions();
94+
getOrMakeCounter(mm, Long.parseLong(n1)).incOverlappingDecompressions(c2);
8595
counter++;
8696
}
8797
if(validSize(r2, c2)) {
8898
getOrMakeCounter(mm, Long.parseLong(n2)).incLMM(r1);
8999
counter++;
90100
}
91-
92101
}
102+
else if(cpIns instanceof MMChainCPInstruction) {
103+
final String n1 = cpIns.input1.getName();
104+
getOrMakeCounter(mm, Long.parseLong(n1)).incRMM(1);
105+
getOrMakeCounter(mm, Long.parseLong(n1)).incLMM(1);
106+
counter++;
107+
}
108+
else if(cpIns instanceof AggregateUnaryCPInstruction) {
109+
Operator op = cpIns.getOperator();
110+
final String n1 = cpIns.input1.getName();
111+
long id = Long.parseLong(n1);
112+
// MatrixObject d1 = (MatrixObject) ec.getCacheableData(n1);
113+
// int r1 = (int) d1.getDim(0);
114+
// int c1 = (int) d1.getDim(1);
115+
if(op instanceof AggregateUnaryOperator) {
116+
AggregateUnaryOperator aop = (AggregateUnaryOperator) op;
117+
IndexFunction idxF = aop.indexFn;
118+
getOrMakeCounter(mm, id).incDictOps();
119+
if(idxF instanceof ReduceCol) {
120+
if((aop.aggOp.increOp.fn instanceof KahanPlus //
121+
|| aop.aggOp.increOp.fn instanceof Plus //
122+
|| aop.aggOp.increOp.fn instanceof Mean)) {
123+
getOrMakeCounter(mm, id).incDictOps();
124+
}
125+
else {
126+
// increment decompression if row reduce.
127+
getOrMakeCounter(mm, id).incDecompressions();
128+
}
129+
}
130+
else {
131+
getOrMakeCounter(mm, id).incDictOps();
132+
}
133+
}
134+
}
135+
93136
}
94137

95138
private static InstructionTypeCounter getOrMakeCounter(ConcurrentHashMap<Long, InstructionTypeCounter> mm, long id) {
@@ -117,8 +160,8 @@ private static boolean validSize(int nRow, int nCol) {
117160
return nRow > 90 && nRow >= nCol;
118161
}
119162

120-
@Override
121-
public String toString(){
163+
@Override
164+
public String toString() {
122165
StringBuilder sb = new StringBuilder();
123166
sb.append(this.getClass().getSimpleName());
124167
sb.append(" Counter: ");

0 commit comments

Comments
 (0)