2727import org .apache .sysds .runtime .compress .cost .InstructionTypeCounter ;
2828import org .apache .sysds .runtime .controlprogram .caching .MatrixObject ;
2929import org .apache .sysds .runtime .controlprogram .context .ExecutionContext ;
30+ import org .apache .sysds .runtime .functionobjects .IndexFunction ;
31+ import org .apache .sysds .runtime .functionobjects .KahanPlus ;
32+ import org .apache .sysds .runtime .functionobjects .Mean ;
33+ import org .apache .sysds .runtime .functionobjects .Plus ;
34+ import org .apache .sysds .runtime .functionobjects .ReduceCol ;
3035import org .apache .sysds .runtime .instructions .Instruction ;
3136import org .apache .sysds .runtime .instructions .cp .AggregateBinaryCPInstruction ;
37+ import org .apache .sysds .runtime .instructions .cp .AggregateUnaryCPInstruction ;
3238import org .apache .sysds .runtime .instructions .cp .ComputationCPInstruction ;
39+ import org .apache .sysds .runtime .instructions .cp .MMChainCPInstruction ;
40+ import org .apache .sysds .runtime .matrix .operators .AggregateUnaryOperator ;
41+ import org .apache .sysds .runtime .matrix .operators .Operator ;
3342
3443public class FederatedWorkloadAnalyzer {
3544 protected static final Log LOG = LogFactory .getLog (FederatedWorkloadAnalyzer .class .getName ());
@@ -55,7 +64,7 @@ public void incrementWorkload(ExecutionContext ec, long tid, Instruction ins) {
5564 }
5665
5766 public void compressRun (ExecutionContext ec , long tid ) {
58- if (counter >= compressRunFrequency ) {
67+ if (counter >= compressRunFrequency ) {
5968 counter = 0 ;
6069 get (tid ).forEach ((K , V ) -> CompressedMatrixBlockFactory .compressAsync (ec , Long .toString (K ), V ));
6170 }
@@ -68,6 +77,7 @@ private void incrementWorkload(ExecutionContext ec, long tid, ComputationCPInstr
6877 public void incrementWorkload (ExecutionContext ec , ConcurrentHashMap <Long , InstructionTypeCounter > mm ,
6978 ComputationCPInstruction cpIns ) {
7079 // TODO: Count transitive closure via lineage
80+ // TODO: add more operations
7181 if (cpIns instanceof AggregateBinaryCPInstruction ) {
7282 final String n1 = cpIns .input1 .getName ();
7383 MatrixObject d1 = (MatrixObject ) ec .getCacheableData (n1 );
@@ -81,15 +91,48 @@ public void incrementWorkload(ExecutionContext ec, ConcurrentHashMap<Long, Instr
8191 if (validSize (r1 , c1 )) {
8292 getOrMakeCounter (mm , Long .parseLong (n1 )).incRMM (c2 );
8393 // safety add overlapping decompress for RMM
84- getOrMakeCounter (mm , Long .parseLong (n1 )).incOverlappingDecompressions ();
94+ getOrMakeCounter (mm , Long .parseLong (n1 )).incOverlappingDecompressions (c2 );
8595 counter ++;
8696 }
8797 if (validSize (r2 , c2 )) {
8898 getOrMakeCounter (mm , Long .parseLong (n2 )).incLMM (r1 );
8999 counter ++;
90100 }
91-
92101 }
102+ else if (cpIns instanceof MMChainCPInstruction ) {
103+ final String n1 = cpIns .input1 .getName ();
104+ getOrMakeCounter (mm , Long .parseLong (n1 )).incRMM (1 );
105+ getOrMakeCounter (mm , Long .parseLong (n1 )).incLMM (1 );
106+ counter ++;
107+ }
108+ else if (cpIns instanceof AggregateUnaryCPInstruction ) {
109+ Operator op = cpIns .getOperator ();
110+ final String n1 = cpIns .input1 .getName ();
111+ long id = Long .parseLong (n1 );
112+ // MatrixObject d1 = (MatrixObject) ec.getCacheableData(n1);
113+ // int r1 = (int) d1.getDim(0);
114+ // int c1 = (int) d1.getDim(1);
115+ if (op instanceof AggregateUnaryOperator ) {
116+ AggregateUnaryOperator aop = (AggregateUnaryOperator ) op ;
117+ IndexFunction idxF = aop .indexFn ;
118+ getOrMakeCounter (mm , id ).incDictOps ();
119+ if (idxF instanceof ReduceCol ) {
120+ if ((aop .aggOp .increOp .fn instanceof KahanPlus //
121+ || aop .aggOp .increOp .fn instanceof Plus //
122+ || aop .aggOp .increOp .fn instanceof Mean )) {
123+ getOrMakeCounter (mm , id ).incDictOps ();
124+ }
125+ else {
126+ // increment decompression if row reduce.
127+ getOrMakeCounter (mm , id ).incDecompressions ();
128+ }
129+ }
130+ else {
131+ getOrMakeCounter (mm , id ).incDictOps ();
132+ }
133+ }
134+ }
135+
93136 }
94137
95138 private static InstructionTypeCounter getOrMakeCounter (ConcurrentHashMap <Long , InstructionTypeCounter > mm , long id ) {
@@ -117,8 +160,8 @@ private static boolean validSize(int nRow, int nCol) {
117160 return nRow > 90 && nRow >= nCol ;
118161 }
119162
120- @ Override
121- public String toString (){
163+ @ Override
164+ public String toString () {
122165 StringBuilder sb = new StringBuilder ();
123166 sb .append (this .getClass ().getSimpleName ());
124167 sb .append (" Counter: " );
0 commit comments