Skip to content

Commit eb81e3b

Browse files
committed
Add unit tests for FederatedWorkloadAnalyzer workload tracking
Cover the instruction-shape branches in incrementWorkload that drive federated compression decisions, which previously had no direct tests: - AggregateBinary: RMM/LMM counting, overlapping-decompress sizing by the right-hand column count, and the validSize row/column guards - MMChain: one LMM and one RMM contribution per invocation - AggregateUnary: dict-op vs decompression classification across ReduceAll/ReduceRow/ReduceCol with sum, mean, product, and max operators - Instance-level dispatch and compressRun threshold behavior, asserting async compression materializes when the cost model would compress
1 parent fdfc718 commit eb81e3b

1 file changed

Lines changed: 343 additions & 0 deletions

File tree

Lines changed: 343 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,343 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.sysds.test.component.federated;
21+
22+
import static org.junit.Assert.assertEquals;
23+
import static org.junit.Assert.assertFalse;
24+
import static org.junit.Assert.assertTrue;
25+
import static org.junit.Assert.fail;
26+
27+
import java.util.concurrent.ConcurrentHashMap;
28+
29+
import org.apache.commons.logging.Log;
30+
import org.apache.commons.logging.LogFactory;
31+
import org.apache.sysds.common.Opcodes;
32+
import org.apache.sysds.common.Types.FileFormat;
33+
import org.apache.sysds.common.Types.ValueType;
34+
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
35+
import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory;
36+
import org.apache.sysds.runtime.compress.cost.InstructionTypeCounter;
37+
import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
38+
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
39+
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
40+
import org.apache.sysds.runtime.controlprogram.federated.FederatedWorkloadAnalyzer;
41+
import org.apache.sysds.runtime.instructions.Instruction;
42+
import org.apache.sysds.runtime.instructions.InstructionUtils;
43+
import org.apache.sysds.runtime.instructions.cp.AggregateBinaryCPInstruction;
44+
import org.apache.sysds.runtime.instructions.cp.AggregateUnaryCPInstruction;
45+
import org.apache.sysds.runtime.instructions.cp.ComputationCPInstruction;
46+
import org.apache.sysds.runtime.instructions.cp.MMChainCPInstruction;
47+
import org.apache.sysds.runtime.instructions.cp.ReorgCPInstruction;
48+
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
49+
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
50+
import org.apache.sysds.runtime.meta.MetaDataFormat;
51+
import org.apache.sysds.test.TestUtils;
52+
import org.junit.Test;
53+
54+
public class FederatedWorkloadAnalyzerTest {
55+
protected static final Log LOG = LogFactory.getLog(FederatedWorkloadAnalyzerTest.class.getName());
56+
57+
/** Async compression triggered by compressRun runs on a worker thread, so poll instead of sleeping. */
58+
private static final int COMPRESS_TIMEOUT_MS = 10000;
59+
60+
private final FederatedWorkloadAnalyzer analyzer = new FederatedWorkloadAnalyzer();
61+
62+
// --------------------------------------------------------------------------------------------
63+
// AggregateBinary (matrix multiply)
64+
// --------------------------------------------------------------------------------------------
65+
66+
@Test
67+
public void aggregateBinaryBothSidesCounted() {
68+
// left 100x100 (valid), right 100x50 (valid)
69+
ExecutionContext ec = ec("1", mo(100, 100), "2", mo(100, 50));
70+
ConcurrentHashMap<Long, InstructionTypeCounter> mm = new ConcurrentHashMap<>();
71+
72+
analyzer.incrementWorkload(ec, mm, mm("1", "2"));
73+
74+
// left side: RMM with the right-hand column count, plus overlapping decompress sized by c2
75+
InstructionTypeCounter left = mm.get(1L);
76+
assertEquals(50, left.getRightMultiplications());
77+
assertEquals(50, left.getOverlappingDecompressions());
78+
// right side: LMM with the left-hand row count
79+
InstructionTypeCounter right = mm.get(2L);
80+
assertEquals(100, right.getLeftMultiplications());
81+
}
82+
83+
@Test
84+
public void aggregateBinaryOnlyLeftCountedWhenRightTooSmall() {
85+
// left 100x10 (valid), right 10x5 (too few rows -> invalid)
86+
ExecutionContext ec = ec("1", mo(100, 10), "2", mo(10, 5));
87+
ConcurrentHashMap<Long, InstructionTypeCounter> mm = new ConcurrentHashMap<>();
88+
89+
analyzer.incrementWorkload(ec, mm, mm("1", "2"));
90+
91+
InstructionTypeCounter left = mm.get(1L);
92+
assertEquals(5, left.getRightMultiplications());
93+
assertEquals(5, left.getOverlappingDecompressions());
94+
// right side never tracked because it does not pass validSize
95+
assertFalse(mm.containsKey(2L));
96+
}
97+
98+
@Test
99+
public void aggregateBinaryNeitherCountedWhenBothTooSmall() {
100+
ExecutionContext ec = ec("1", mo(10, 10), "2", mo(10, 10));
101+
ConcurrentHashMap<Long, InstructionTypeCounter> mm = new ConcurrentHashMap<>();
102+
103+
analyzer.incrementWorkload(ec, mm, mm("1", "2"));
104+
105+
assertTrue(mm.isEmpty());
106+
}
107+
108+
@Test
109+
public void aggregateBinaryWideOperandNotCounted() {
110+
// 100x200: enough rows (>90) but more columns than rows -> validSize false on the second clause
111+
ExecutionContext ec = ec("1", mo(100, 200), "2", mo(10, 5));
112+
ConcurrentHashMap<Long, InstructionTypeCounter> mm = new ConcurrentHashMap<>();
113+
114+
analyzer.incrementWorkload(ec, mm, mm("1", "2"));
115+
116+
assertTrue(mm.isEmpty());
117+
}
118+
119+
// --------------------------------------------------------------------------------------------
120+
// MMChain
121+
// --------------------------------------------------------------------------------------------
122+
123+
@Test
124+
public void mmChainCountsOneLeftAndOneRight() {
125+
ConcurrentHashMap<Long, InstructionTypeCounter> mm = new ConcurrentHashMap<>();
126+
127+
analyzer.incrementWorkload(null, mm, mmchain("1"));
128+
129+
InstructionTypeCounter c = mm.get(1L);
130+
assertEquals(1, c.getRightMultiplications());
131+
assertEquals(1, c.getLeftMultiplications());
132+
}
133+
134+
// --------------------------------------------------------------------------------------------
135+
// AggregateUnary
136+
// --------------------------------------------------------------------------------------------
137+
138+
@Test
139+
public void aggregateUnaryColSumsIsDictOp() {
140+
// colSums -> ReduceRow -> compression friendly (2 dict ops, no decompress)
141+
assertDictOpsAndDecompress(Opcodes.UACKP.toString(), 2, 0);
142+
}
143+
144+
@Test
145+
public void aggregateUnaryFullSumIsDictOp() {
146+
// sum -> ReduceAll -> compression friendly (2 dict ops, no decompress)
147+
assertDictOpsAndDecompress(Opcodes.UAKP.toString(), 2, 0);
148+
}
149+
150+
@Test
151+
public void aggregateUnaryRowSumsIsDictOp() {
152+
// rowSums -> ReduceCol with KahanPlus -> compression friendly (2 dict ops, no decompress)
153+
assertDictOpsAndDecompress(Opcodes.UARKP.toString(), 2, 0);
154+
}
155+
156+
@Test
157+
public void aggregateUnaryRowMeansIsDictOp() {
158+
// rowMeans -> ReduceCol with Mean -> compression friendly (2 dict ops, no decompress)
159+
assertDictOpsAndDecompress(Opcodes.UARMEAN.toString(), 2, 0);
160+
}
161+
162+
@Test
163+
public void aggregateUnaryRowProductsForcesDecompress() {
164+
// rowProds -> ReduceCol with Multiply -> not friendly (1 dict op + 1 decompress)
165+
assertDictOpsAndDecompress(Opcodes.UARM.toString(), 1, 1);
166+
}
167+
168+
@Test
169+
public void aggregateUnaryRowSumsPlusIsDictOp() {
170+
// rowSums (plain Plus, no Kahan) -> ReduceCol with Plus -> compression friendly (2 dict ops)
171+
assertDictOpsAndDecompress(Opcodes.UARP.toString(), 2, 0);
172+
}
173+
174+
@Test
175+
public void aggregateUnaryRowMaxForcesDecompress() {
176+
// rowMax -> ReduceCol with Builtin max -> not friendly (1 dict op + 1 decompress)
177+
assertDictOpsAndDecompress(Opcodes.UARMAX.toString(), 1, 1);
178+
}
179+
180+
@Test
181+
public void aggregateUnaryNonAggregateOperatorIgnored() {
182+
// nrow uses a SimpleOperator (not an AggregateUnaryOperator) so nothing is tracked
183+
ConcurrentHashMap<Long, InstructionTypeCounter> mm = new ConcurrentHashMap<>();
184+
185+
analyzer.incrementWorkload(null, mm, uagg(Opcodes.NROW.toString(), "1"));
186+
187+
assertTrue(mm.isEmpty());
188+
}
189+
190+
private void assertDictOpsAndDecompress(String opcode, int expectedDictOps, int expectedDecompress) {
191+
ConcurrentHashMap<Long, InstructionTypeCounter> mm = new ConcurrentHashMap<>();
192+
193+
analyzer.incrementWorkload(null, mm, uagg(opcode, "1"));
194+
195+
InstructionTypeCounter c = mm.get(1L);
196+
assertEquals("Unexpected dict-ops for " + opcode, expectedDictOps, c.getDictionaryOps());
197+
assertEquals("Unexpected decompressions for " + opcode, expectedDecompress, c.getDecompressions());
198+
}
199+
200+
// --------------------------------------------------------------------------------------------
201+
// Instance level dispatch + async compress trigger
202+
// --------------------------------------------------------------------------------------------
203+
204+
@Test
205+
public void compressRunCompressesAfterEnoughWorkload() {
206+
final long tid = 1;
207+
final int dim = 100, iter = 10;
208+
// Right operand is left-multiplied each matmul, accumulating LMM = leftRows (=dim) per
209+
// invocation, so iter=10 yields LMM=1000 on a 100x100 rounded block. This mirrors the shape
210+
// and counter that FedWorkerMatrixMultiplyWorkload relies on to trigger compression.
211+
MatrixBlock rightBlock = TestUtils.round(TestUtils.generateTestMatrixBlock(dim, dim, 0.5, 2.5, 1.0, 222));
212+
MatrixBlock probeBlock = new MatrixBlock();
213+
probeBlock.copy(rightBlock);
214+
215+
MatrixObject left = compressibleMO(dim, dim, 7);
216+
MatrixObject right = wrap(rightBlock);
217+
ExecutionContext ec = ec("1", left, "2", right);
218+
219+
// each matmul with two valid sides increments the counter twice; reaching the
220+
// compressRunFrequency threshold of 10 schedules an async compression pass
221+
ComputationCPInstruction ins = mm("1", "2");
222+
for(int i = 0; i < iter; i++)
223+
analyzer.incrementWorkload(ec, tid, ins);
224+
225+
analyzer.compressRun(ec, tid);
226+
227+
// Only assert the async compression materialized if the cost model would compress this shape
228+
// locally; otherwise the workload pass legitimately leaves it uncompressed (matches the skip
229+
// pattern in FedWorkerMatrixMultiplyWorkload).
230+
InstructionTypeCounter probe = new InstructionTypeCounter(0, 0, 0, dim * iter, 0, 0, 0, 0, false);
231+
boolean locallyCompressible = CompressedMatrixBlockFactory.compress(probeBlock, probe)
232+
.getLeft() instanceof CompressedMatrixBlock;
233+
if(locallyCompressible)
234+
assertCompressedWithinTimeout(right);
235+
}
236+
237+
@Test
238+
public void compressRunNoOpBelowThreshold() {
239+
final long tid = 2;
240+
MatrixObject left = compressibleMO(500, 10, 7);
241+
MatrixObject right = compressibleMO(500, 10, 13);
242+
ExecutionContext ec = ec("1", left, "2", right);
243+
244+
// only two invocations -> counter = 4, below threshold, so nothing compresses
245+
ComputationCPInstruction ins = mm("1", "2");
246+
analyzer.incrementWorkload(ec, tid, ins);
247+
analyzer.incrementWorkload(ec, tid, ins);
248+
249+
analyzer.compressRun(ec, tid);
250+
251+
assertFalse(left.acquireReadAndRelease() instanceof CompressedMatrixBlock);
252+
assertFalse(right.acquireReadAndRelease() instanceof CompressedMatrixBlock);
253+
}
254+
255+
@Test
256+
public void nonComputationInstructionIgnored() {
257+
// the public entry point silently ignores non-CP / non-computation instructions
258+
analyzer.incrementWorkload(null, 99, (Instruction) null);
259+
analyzer.compressRun(null, 99);
260+
}
261+
262+
@Test
263+
public void unhandledComputationInstructionIgnored() {
264+
// a transpose is a ComputationCPInstruction but none of the tracked shapes -> no counters
265+
ConcurrentHashMap<Long, InstructionTypeCounter> mm = new ConcurrentHashMap<>();
266+
267+
analyzer.incrementWorkload(null, mm, reorg("1"));
268+
269+
assertTrue(mm.isEmpty());
270+
}
271+
272+
@Test
273+
public void toStringReportsState() {
274+
String s = analyzer.toString();
275+
assertTrue(s.contains(FederatedWorkloadAnalyzer.class.getSimpleName()));
276+
assertTrue(s.contains("Counter"));
277+
}
278+
279+
// --------------------------------------------------------------------------------------------
280+
// helpers
281+
// --------------------------------------------------------------------------------------------
282+
283+
private static void assertCompressedWithinTimeout(MatrixObject mo) {
284+
final long deadline = System.currentTimeMillis() + COMPRESS_TIMEOUT_MS;
285+
while(System.currentTimeMillis() < deadline) {
286+
if(mo.acquireReadAndRelease() instanceof CompressedMatrixBlock)
287+
return;
288+
try {
289+
Thread.sleep(50);
290+
}
291+
catch(InterruptedException e) {
292+
Thread.currentThread().interrupt();
293+
fail("Interrupted while waiting for async compression");
294+
}
295+
}
296+
fail("Matrix was not compressed by the workload analyzer within " + COMPRESS_TIMEOUT_MS + "ms");
297+
}
298+
299+
private static ExecutionContext ec(String n1, MatrixObject m1, String n2, MatrixObject m2) {
300+
LocalVariableMap vars = new LocalVariableMap();
301+
ExecutionContext ec = new ExecutionContext(vars);
302+
ec.setVariable(n1, m1);
303+
ec.setVariable(n2, m2);
304+
return ec;
305+
}
306+
307+
/** Build a MatrixObject of the requested shape (data content irrelevant for the counters). */
308+
private static MatrixObject mo(int rows, int cols) {
309+
return wrap(new MatrixBlock(rows, cols, 0.0));
310+
}
311+
312+
private static MatrixObject compressibleMO(int rows, int cols, int seed) {
313+
return wrap(TestUtils.round(TestUtils.generateTestMatrixBlock(rows, cols, 0, 3, 1.0, seed)));
314+
}
315+
316+
private static MatrixObject wrap(MatrixBlock mb) {
317+
MatrixCharacteristics mc = new MatrixCharacteristics(mb.getNumRows(), mb.getNumColumns(), -1, mb.getNonZeros());
318+
MetaDataFormat md = new MetaDataFormat(mc, FileFormat.BINARY);
319+
MatrixObject mo = new MatrixObject(ValueType.FP64, "/dev/null", md, mb);
320+
mo.getDataCharacteristics().setDimension(mb.getNumRows(), mb.getNumColumns());
321+
return mo;
322+
}
323+
324+
private static ComputationCPInstruction mm(String in1, String in2) {
325+
String str = InstructionUtils.concatOperands("CP", Opcodes.MMULT.toString(), in1, in2, "3", "16");
326+
return AggregateBinaryCPInstruction.parseInstruction(str);
327+
}
328+
329+
private static ComputationCPInstruction mmchain(String in1) {
330+
String str = InstructionUtils.concatOperands("CP", Opcodes.MMCHAIN.toString(), in1, "2", "3", "XtXv", "16");
331+
return MMChainCPInstruction.parseInstruction(str);
332+
}
333+
334+
private static ComputationCPInstruction uagg(String opcode, String in1) {
335+
String str = InstructionUtils.concatOperands("CP", opcode, in1, "2", "16");
336+
return AggregateUnaryCPInstruction.parseInstruction(str);
337+
}
338+
339+
private static ComputationCPInstruction reorg(String in1) {
340+
String str = InstructionUtils.concatOperands("CP", Opcodes.TRANSPOSE.toString(), in1, "2", "16");
341+
return ReorgCPInstruction.parseInstruction(str);
342+
}
343+
}

0 commit comments

Comments
 (0)