From 3e55f991e3918d17015b69185c65bdc63cab7ea4 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 30 Dec 2024 22:58:46 +0100 Subject: [PATCH 01/81] [SYSTEMDS-3644] Compressed transform encode --- .../runtime/compress/colgroup/ColGroupIO.java | 4 +- .../compress/lib/CLALibDecompress.java | 22 +- .../compress/lib/CLALibRightMultBy.java | 9 +- .../sysds/runtime/frame/data/FrameBlock.java | 4 +- .../frame/data/columns/ABooleanArray.java | 12 +- .../frame/data/columns/ACompressedArray.java | 5 - .../runtime/frame/data/columns/Array.java | 206 ++++++-- .../frame/data/columns/ArrayFactory.java | 2 +- .../frame/data/columns/BitSetArray.java | 5 - .../frame/data/columns/BooleanArray.java | 5 - .../runtime/frame/data/columns/CharArray.java | 5 - .../runtime/frame/data/columns/DDCArray.java | 50 +- .../frame/data/columns/DoubleArray.java | 5 - .../frame/data/columns/FloatArray.java | 5 - .../frame/data/columns/HashIntegerArray.java | 30 +- .../frame/data/columns/HashLongArray.java | 27 +- .../frame/data/columns/IntegerArray.java | 7 +- .../runtime/frame/data/columns/LongArray.java | 9 +- .../frame/data/columns/OptionalArray.java | 40 +- .../frame/data/columns/RaggedArray.java | 6 +- .../frame/data/columns/StringArray.java | 16 +- .../compress/CompressedFrameBlockFactory.java | 6 +- ...turnParameterizedBuiltinSPInstruction.java | 4 +- .../runtime/matrix/data/MatrixBlock.java | 9 +- .../transform/encode/ColumnEncoder.java | 23 +- .../encode/ColumnEncoderBagOfWords.java | 22 +- .../transform/encode/ColumnEncoderBin.java | 55 ++- .../encode/ColumnEncoderComposite.java | 14 +- .../encode/ColumnEncoderDummycode.java | 19 +- .../encode/ColumnEncoderFeatureHash.java | 19 +- .../encode/ColumnEncoderPassThrough.java | 77 +-- .../transform/encode/ColumnEncoderRecode.java | 45 +- .../encode/ColumnEncoderWordEmbedding.java | 13 +- .../transform/encode/CompressedEncode.java | 459 ++++++++++++------ .../transform/encode/EncoderMVImpute.java | 2 +- 
.../transform/encode/MultiColumnEncoder.java | 70 ++- .../sysds/runtime/util/CommonThreadPool.java | 209 ++++++-- .../java/org/apache/sysds/test/TestUtils.java | 9 + .../component/frame/FrameApplySchema.java | 6 +- .../frame/array/CustomArrayTests.java | 34 +- .../frame/array/FrameArrayTests.java | 29 +- .../component/frame/array/RecodeMapTest.java | 106 ++++ .../TransformCompressedTestLogger.java | 13 +- .../TransformCompressedTestMultiCol.java | 14 + .../TransformCompressedTestSingleCol.java | 20 +- .../{ThreadPool.java => ThreadPoolTests.java} | 257 +++++++++- .../component/resource/RecompilationTest.java | 2 +- .../part1/BuiltinAdasynRealDataTest.java | 2 +- .../ColumnEncoderSerializationTest.java | 10 +- 49 files changed, 1428 insertions(+), 594 deletions(-) create mode 100644 src/test/java/org/apache/sysds/test/component/frame/array/RecodeMapTest.java rename src/test/java/org/apache/sysds/test/component/misc/{ThreadPool.java => ThreadPoolTests.java} (69%) diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupIO.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupIO.java index 92bb7f550d8..91442281317 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupIO.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupIO.java @@ -94,7 +94,9 @@ public static long getExactSizeOnDisk(List colGroups) { } ret += grp.getExactSizeOnDisk(); } - LOG.error(" duplicate dicts on exact Size on Disk : " + (colGroups.size() - dicts.size()) ); + if(LOG.isWarnEnabled()) + LOG.warn(" duplicate dicts on exact Size on Disk : " + (colGroups.size() - dicts.size()) ); + return ret; } diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibDecompress.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibDecompress.java index 1f0c9d9fc4e..e754ee6b1e3 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibDecompress.java +++ 
b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibDecompress.java @@ -237,7 +237,7 @@ private static MatrixBlock decompressExecute(CompressedMatrixBlock cmb, int k) { if(ret.isInSparseFormat()) decompressSparseSingleThread(ret, filteredGroups, nRows, blklen); else - decompressDenseSingleThread(ret, filteredGroups, nRows, blklen, constV, eps, nonZeros, overlapping); + decompressDenseSingleThread(ret, filteredGroups, nRows, blklen, constV, eps, overlapping); } else if(ret.isInSparseFormat()) decompressSparseMultiThread(ret, filteredGroups, nRows, blklen, k); @@ -294,7 +294,7 @@ private static void decompressSparseSingleThread(MatrixBlock ret, List filteredGroups, int rlen, - int blklen, double[] constV, double eps, long nonZeros, boolean overlapping) { + int blklen, double[] constV, double eps, boolean overlapping) { final DenseBlock db = ret.getDenseBlock(); final int nCol = ret.getNumColumns(); @@ -308,21 +308,19 @@ private static void decompressDenseSingleThread(MatrixBlock ret, List } } - // private static void decompressDenseMultiThread(MatrixBlock ret, List groups, double[] constV, int k, - // boolean overlapping) { - // final int nRows = ret.getNumRows(); - // final double eps = getEps(constV); - // final int blklen = Math.max(nRows / k, 512); - // decompressDenseMultiThread(ret, groups, nRows, blklen, constV, eps, k, overlapping); - // } - - protected static void decompressDenseMultiThread(MatrixBlock ret, List groups, double[] constV, + public static void decompressDense(MatrixBlock ret, List groups, double[] constV, double eps, int k, boolean overlapping) { Timing time = new Timing(true); final int nRows = ret.getNumRows(); final int blklen = Math.max(nRows / k, 512); - decompressDenseMultiThread(ret, groups, nRows, blklen, constV, eps, k, overlapping); + if( k > 1) + decompressDenseMultiThread(ret, groups, nRows, blklen, constV, eps, k, overlapping); + else + decompressDenseSingleThread(ret, groups, nRows, blklen, constV, eps, overlapping); + + 
ret.recomputeNonZeros(k); + if(DMLScript.STATISTICS) { final double t = time.stop(); DMLCompressionStatistics.addDecompressTime(t, k); diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java index 597c38bf9ac..5d6de813fcf 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java @@ -189,15 +189,8 @@ private static MatrixBlock RMM(CompressedMatrixBlock m1, MatrixBlock that, int k constV = mmTemp.isEmpty() ? null : mmTemp.getDenseBlockValues(); } - final Timing time = new Timing(true); - ret = asyncRet(f); - CLALibDecompress.decompressDenseMultiThread(ret, retCg, constV, 0, k, true); - - if(DMLScript.STATISTICS) { - final double t = time.stop(); - DMLCompressionStatistics.addDecompressTime(t, k); - } + CLALibDecompress.decompressDense(ret, retCg, constV, 0, k, true); return ret; } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java index 61875d2e140..7566ba2fd55 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java @@ -1267,8 +1267,8 @@ public void copy(int rl, int ru, int cl, int cu, FrameBlock src) { * @param col is the column # from frame data which contains Recode map generated earlier. 
* @return map of token and code for every element in the input column of a frame containing Recode map */ - public Map getRecodeMap(int col) { - return _coldata[col].getRecodeMap(); + public Map getRecodeMap(int col) { + return _coldata[col].getRecodeMap(4); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/ABooleanArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/ABooleanArray.java index e3fcb2c9f63..848bc38796b 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/ABooleanArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/ABooleanArray.java @@ -21,6 +21,7 @@ import java.util.HashMap; import java.util.Map; +import java.util.concurrent.ExecutorService; public abstract class ABooleanArray extends Array { @@ -55,13 +56,14 @@ public boolean possiblyContainsNaN() { * @param value The string array to set from. */ public abstract void setNullsFromString(int rl, int ru, Array value); - + + @Override - protected Map createRecodeMap() { - Map map = new HashMap<>(); - long id = 1; + protected Map createRecodeMap(int estimate, ExecutorService pool) { + Map map = new HashMap<>(); + int id = 1; for(int i = 0; i < size() && id <= 2; i++) { - Long v = map.putIfAbsent(get(i), id); + Integer v = map.putIfAbsent(get(i), id); if(v == null) id++; } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/ACompressedArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/ACompressedArray.java index 50059999676..81f531fa831 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/ACompressedArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/ACompressedArray.java @@ -59,11 +59,6 @@ public void setFromOtherType(int rl, int ru, Array value) { throw new DMLCompressionException("Invalid to set value in CompressedArray"); } - @Override - public void set(int rl, int ru, Array value, int rlSrc) { - throw new 
DMLCompressionException("Invalid to set value in CompressedArray"); - } - @Override public void setNz(int rl, int ru, Array value) { throw new DMLCompressionException("Invalid to set value in CompressedArray"); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java index 15c9f371ea0..0757d24b525 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java @@ -20,20 +20,28 @@ package org.apache.sysds.runtime.frame.data.columns; import java.lang.ref.SoftReference; +import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.Iterator; +import java.util.List; import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.hadoop.io.Writable; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData; import org.apache.sysds.runtime.compress.estim.sample.SampleEstimatorFactory; import org.apache.sysds.runtime.frame.data.columns.ArrayFactory.FrameArrayType; import org.apache.sysds.runtime.frame.data.compress.ArrayCompressionStatistics; import org.apache.sysds.runtime.matrix.data.Pair; +import org.apache.sysds.utils.stats.Timing; /** * Generic, resizable native arrays for the internal representation of the columns in the FrameBlock. 
We use this custom @@ -42,8 +50,11 @@ public abstract class Array implements Writable { protected static final Log LOG = LogFactory.getLog(Array.class.getName()); + /** Parallelization threshold for parallelizing vector operations */ + public static int ROW_PARALLELIZATION_THRESHOLD = 10000; + /** A soft reference to a memorization of this arrays mapping, used in transformEncode */ - protected SoftReference> _rcdMapCache = null; + protected SoftReference> _rcdMapCache = null; /** The current allocated number of elements in this Array */ protected int _size; @@ -63,7 +74,7 @@ protected int newSize() { * * @return The cached recode map */ - public final SoftReference> getCache() { + public final SoftReference> getCache() { return _rcdMapCache; } @@ -72,7 +83,7 @@ public final SoftReference> getCache() { * * @param m The element to cache. */ - public final void setCache(SoftReference> m) { + public final void setCache(SoftReference> m) { _rcdMapCache = m; } @@ -83,16 +94,49 @@ public final void setCache(SoftReference> m) { * * @return A recode map */ - public synchronized final Map getRecodeMap() { + public synchronized final Map getRecodeMap() { + return getRecodeMap(4); + } + + /** + * Get a recode map that maps each unique value in the array, to a long ID. Null values are ignored, and not included + * in the mapping. The resulting recode map in stored in a soft reference to speed up repeated calls to the same + * column. + * + * @param estimate The estimated number of unique values in this Array to start the initial hashmap size at + * @return A recode map + */ + public synchronized final Map getRecodeMap(int estimate) { + try { + return getRecodeMap(estimate, null); + } + catch(Exception e) { + throw new RuntimeException(e); + } + } + + /** + * Get a recode map that maps each unique value in the array, to a long ID. Null values are ignored, and not included + * in the mapping. 
The resulting recode map in stored in a soft reference to speed up repeated calls to the same + * column. + * + * @param estimate the estimated number of unique values in this array. + * @param pool An executor pool to be used for parallel execution (Note this method does not shutdown the pool) + * @return A recode map + * @throws ExecutionException if the parallel execution fails + * @throws InterruptedException if the parallel execution fails + */ + public synchronized final Map getRecodeMap(int estimate, ExecutorService pool) + throws InterruptedException, ExecutionException { // probe cache for existing map - Map map; - SoftReference> tmp = getCache(); + Map map; + SoftReference> tmp = getCache(); map = (tmp != null) ? tmp.get() : null; if(map != null) return map; // construct recode map - map = createRecodeMap(); + map = createRecodeMap(estimate, pool); // put created map into cache setCache(new SoftReference<>(map)); @@ -101,25 +145,99 @@ public synchronized final Map getRecodeMap() { } /** - * Recreate the recode map from what is inside array. This is an internal method for arrays, and the result is cached - * in the main class of the arrays. + * Get a recode map that maps each unique value in the array, to a long ID. Null values are ignored, and not included + * in the mapping. The resulting recode map in stored in a soft reference to speed up repeated calls to the same + * column. + * + * @param estimate The estimate number of unique values inside this array. + * @param pool The thread pool to use for parallel creation of recode map (can be null). (Note this method does + * not shutdown the pool) + * @return The recode map created. 
+ * @throws ExecutionException if the parallel execution fails + * @throws InterruptedException if the parallel execution fails + */ + protected Map createRecodeMap(int estimate, ExecutorService pool) + throws InterruptedException, ExecutionException { + Timing t = new Timing(); + final int s = size(); + int k = OptimizerUtils.getTransformNumThreads(); + Map ret; + if(pool == null || s < ROW_PARALLELIZATION_THRESHOLD) + ret = createRecodeMap(estimate, 0, s); + else + ret = parallelCreateRecodeMap(estimate, pool, s, k); + + if(LOG.isDebugEnabled()) { + String base = "CreateRecodeMap estimate: %10d actual %10d time: %10.5f"; + LOG.debug(String.format(base, estimate, ret.size(), t.stop())); + } + return ret; + } + + private Map parallelCreateRecodeMap(int estimate, ExecutorService pool, final int s, int k) + throws InterruptedException, ExecutionException { + + final int blk = Math.max(ROW_PARALLELIZATION_THRESHOLD / 2, (s + k) / k); + final List>> tasks = new ArrayList<>(); + for(int i = blk; i < s; i += blk) { // start at blk for the other threads + final int start = i; + final int end = Math.min(i + blk, s); + tasks.add(pool.submit(() -> createRecodeMap(estimate, start, end))); + } + // make the initial map thread local allocation. + final Map map = new HashMap<>((int) (estimate * 1.3)); + createRecodeMap(map, 0, blk); + for(int i = 0; i < tasks.size(); i++) { // merge with other threads work. + final Map map2 = tasks.get(i).get(); + mergeRecodeMaps(map, map2); + } + return map; + + } + + /** + * Merge Recode maps, most likely from parallel threads. + * + * If the unique value is present in the target, use that ID, otherwise this method map to new ID's based on the + * target mapping's size. * - * @return The recode map + * @param target The target object to merge the two maps into + * @param from The Map to take entries from. 
*/ - protected Map createRecodeMap() { - Map map = new HashMap<>(); - long id = 1; - for(int i = 0; i < size(); i++) { - T val = get(i); - if(val != null) { - Long v = map.putIfAbsent(val, id); - if(v == null) - id++; - } + protected static void mergeRecodeMaps(Map target, Map from) { + final List fromEntriesOrdered = new ArrayList<>(Collections.nCopies(from.size(), null)); + for(Map.Entry e : from.entrySet()) + fromEntriesOrdered.set(e.getValue() - 1, e.getKey()); + int id = target.size(); + for(T e : fromEntriesOrdered) { + if(target.putIfAbsent(e, id) == null) + id++; } + } + + private Map createRecodeMap(final int estimate, final int s, final int e) { + // * 1.3 because we hashMap has a load factor of 1.75 + final Map map = new HashMap<>((int) (Math.min((long) estimate, (e - s)) * 1.3)); + return createRecodeMap(map, s, e); + } + + private Map createRecodeMap(Map map, final int s, final int e) { + int id = 1; + for(int i = s; i < e; i++) + id = addValRecodeMap(map, id, i); return map; } + protected int addValRecodeMap(Map map, int id, int i) { + T val = getInternal(i); + if(val != null) { + Integer v = map.putIfAbsent(val, id); + if(v == null) + id++; + } + return id; + } + /** * Get the number of elements in the array, this does not necessarily reflect the current allocated size. * @@ -224,14 +342,11 @@ public double getAsNaNDouble(int i) { * * @param rl row lower * @param ru row upper (inclusive) - * @param value value array to take values from (same type) + * @param value value array to take values from (same type) offset by rl. 
*/ - public abstract void set(int rl, int ru, Array value); - - // { - // for(int i = rl; i <= ru; i++) - // set(i, value.get(i)); - // } + public final void set(int rl, int ru, Array value) { + set(rl, ru, value, 0); + } /** * Set range to given arrays value with an offset into other array @@ -243,7 +358,7 @@ public double getAsNaNDouble(int i) { */ public void set(int rl, int ru, Array value, int rlSrc) { for(int i = rl, off = rlSrc; i <= ru; i++, off++) - set(i, value.get(off)); + set(i, value.getInternal(off)); } /** @@ -918,4 +1033,39 @@ public double[] minMax(int l, int u) { } return new double[] {min, max}; } + + /** + * Set the index i in the map given based on the mapping provided. The map should be guaranteed to contain all unique + * values. + * + * @param map A map containing all unique values of this array + * @param m The MapToData to set the value part of the Map from + * @param i The index to set in m + */ + public void setM(Map map, AMapToData m, int i) { + m.set(i, map.get(getInternal(i)).intValue() - 1); + } + + /** + * Set the index i in the map given based on the mapping provided. The map should be guaranteed to contain all unique + * values except null. Therefore in case of null we set the provided si value. 
+ * + * @param map A map containing all unique values of this array + * @param si The default value to use in m if this Array contains null at index i + * @param m The MapToData to set the value part of the Map from + * @param i The index to set in m + */ + public void setM(Map map, int si, AMapToData m, int i) { + try { + final T v = getInternal(i); + if(v != null) + m.set(i, map.get(v).intValue() - 1); + else + m.set(i, si); + } + catch(Exception e) { + String error = "expected: " + getInternal(i) + " to be in map: " + map; + throw new RuntimeException(error, e); + } + } } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/ArrayFactory.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/ArrayFactory.java index 3f21a8f066e..5f2d08a122f 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/ArrayFactory.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/ArrayFactory.java @@ -356,7 +356,7 @@ else if(target.getFrameArrayType() != FrameArrayType.OPTIONAL // Array targetC = (Array) (ta != tc ? target.changeType(tc) : target); Array srcC = (Array) (tb != tc ? 
src.changeType(tc) : src); - targetC.set(rl, ru, srcC); + targetC.set(rl, ru, srcC, 0); return targetC; } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/BitSetArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/BitSetArray.java index 133c0f956c5..4ddc66e8de3 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/BitSetArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/BitSetArray.java @@ -147,11 +147,6 @@ public void set(int index, String value) { set(index, BooleanArray.parseBoolean(value)); } - @Override - public void set(int rl, int ru, Array value) { - set(rl, ru, value, 0); - } - @Override public void setFromOtherType(int rl, int ru, Array value) { final ValueType vt = value.getValueType(); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/BooleanArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/BooleanArray.java index 041a780e527..edf4ee7e685 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/BooleanArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/BooleanArray.java @@ -69,11 +69,6 @@ public void set(int index, String value) { set(index, BooleanArray.parseBoolean(value)); } - @Override - public void set(int rl, int ru, Array value) { - set(rl, ru, value, 0); - } - @Override public void setFromOtherType(int rl, int ru, Array value) { final ValueType vt = value.getValueType(); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/CharArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/CharArray.java index ae57ae167b3..a4192c6440f 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/CharArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/CharArray.java @@ -100,11 +100,6 @@ public void setFromOtherType(int rl, int ru, Array value) { _data[i] = value.get(i).toString().charAt(0); } - @Override - public void set(int rl, int 
ru, Array value) { - set(rl, ru, value, 0); - } - @Override public void set(int rl, int ru, Array value, int rlSrc) { try { diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java index f04093f9de4..ecd827070b3 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java @@ -25,6 +25,8 @@ import java.util.HashMap; import java.util.Map; import java.util.Map.Entry; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.DMLRuntimeException; @@ -53,7 +55,8 @@ public DDCArray(Array dict, AMapToData map) { if(FrameBlock.debug) { if(dict != null && dict.size() != map.getUnique()) - throw new DMLRuntimeException("Invalid DDCArray, dictionary size is not equal to map unique"); + LOG.warn("Invalid DDCArray, dictionary size (" + dict.size() + ") is not equal to map unique (" + + map.getUnique() + ")"); } } @@ -175,8 +178,9 @@ public static Array compressToDDC(Array arr, int estimateUnique) { } @Override - protected Map createRecodeMap() { - return dict.createRecodeMap(); + protected Map createRecodeMap(int estimate, ExecutorService pool) + throws InterruptedException, ExecutionException { + return dict.createRecodeMap(estimate, pool); } @Override @@ -210,6 +214,11 @@ public T get(int index) { return dict.get(map.getIndex(index)); } + @Override + public T getInternal(int index) { + return dict.getInternal(map.getIndex(index)); + } + @Override public double[] extractDouble(double[] ret, int rl, int ru) { // overridden to allow GIT compile @@ -255,30 +264,29 @@ public Pair analyzeValueType(int maxCells) { } @Override - public void set(int rl, int ru, Array value) { + public void set(int rl, int ru, Array value, int rlSrc) { if(value instanceof DDCArray) { 
DDCArray dc = (DDCArray) value; - // we allow one side to have a null dictionary while the other does not. - if((dict != null && dc.dict != null // If both dicts are not null - && (dc.dict.size() != dict.size() // then if size of the dicts are not equivalent - || (FrameBlock.debug && !dc.dict.equals(dict))) // or then if debugging do full equivalence check - ) || map.getUnique() < dc.map.getUnique() // this map is not able to contain values of other. - ) - throw new DMLCompressionException("Invalid setting of DDC Array, of incompatible instance." + // - "\ndict1 is null: " + (dict == null) + // - "\ndict2 is null: " + (dc.dict == null) +// - "\nmap1 unique: " + (map.getUnique()) + // - "\nmap2 unique: " + (dc.map.getUnique()) ); - - final AMapToData tm = dc.map; - for(int i = rl; i <= ru; i++) { - map.set(i, tm.getIndex(i)); - } + checkCompressedSet(dc); + map.set(rl, ru + 1, rlSrc, dc.map); } else throw new DMLCompressionException("Invalid to set value in CompressedArray"); } + private void checkCompressedSet(DDCArray dc) { + if((dict != null && dc.dict != null // If both dicts are not null + && (dc.dict.size() != dict.size() // then if size of the dicts are not equivalent + || (FrameBlock.debug && !dc.dict.equals(dict))) // or then if debugging do full equivalence check + ) || map.getUnique() < dc.map.getUnique() // this map is not able to contain values of other. + ) + throw new DMLCompressionException("Invalid setting of DDC Array, of incompatible instance." 
+ // + "\ndict1 is null: " + (dict == null) + // + "\ndict2 is null: " + (dc.dict == null) + // + "\nmap1 unique: " + (map.getUnique()) + // + "\nmap2 unique: " + (dc.map.getUnique())); + } + @Override public FrameArrayType getFrameArrayType() { return FrameArrayType.DDC; @@ -393,7 +401,7 @@ else if(l > dict.size()) @Override public ArrayCompressionStatistics statistics(int nSamples) { - final long memSize = getInMemorySize(); + final long memSize = getInMemorySize(); final int memSizePerElement = estMemSizePerElement(getValueType(), memSize); return new ArrayCompressionStatistics(memSizePerElement, // diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java index 23f58798249..99cce9f9e97 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java @@ -71,11 +71,6 @@ public void set(int index, String value) { set(index, parseDouble(value)); } - @Override - public void set(int rl, int ru, Array value) { - set(rl, ru, value, 0); - } - @Override public void setFromOtherType(int rl, int ru, Array value) { final ValueType vt = value.getValueType(); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/FloatArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/FloatArray.java index fc1a7aed5ae..d586c2f32a8 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/FloatArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/FloatArray.java @@ -73,11 +73,6 @@ public void set(int index, String value) { set(index, parseFloat(value)); } - @Override - public void set(int rl, int ru, Array value) { - set(rl, ru, value, 0); - } - @Override public void setFromOtherType(int rl, int ru, Array value) { final ValueType vt = value.getValueType(); diff --git 
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashIntegerArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashIntegerArray.java index 131036d2085..328c9a565fe 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashIntegerArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashIntegerArray.java @@ -23,10 +23,12 @@ import java.io.DataOutput; import java.io.IOException; import java.util.Arrays; +import java.util.Map; import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData; import org.apache.sysds.runtime.frame.data.columns.ArrayFactory.FrameArrayType; import org.apache.sysds.runtime.matrix.data.Pair; import org.apache.sysds.runtime.util.UtilFunctions; @@ -103,11 +105,6 @@ public void set(int index, double value) { _data[index] = (int) value; } - @Override - public void set(int rl, int ru, Array value) { - set(rl, ru, value, 0); - } - @Override public void setFromOtherType(int rl, int ru, Array value) { for(int i = rl; i <= ru; i++) @@ -370,7 +367,8 @@ else if(s instanceof Long) else if(s instanceof Integer) return (Integer) s; else - throw new NotImplementedException("not supported parsing: " + s + " of class: " + s.getClass().getSimpleName()); + throw new NotImplementedException( + "not supported parsing: " + s + " of class: " + s.getClass().getSimpleName()); } public static int parseHashInt(String s) { @@ -435,6 +433,26 @@ public boolean possiblyContainsNaN() { return false; } + @Override + protected int addValRecodeMap(Map map, int id, int i) { + Integer val = Integer.valueOf(getInt(i)); + Integer v = map.putIfAbsent(val, id); + if(v == null) + id++; + return id; + } + + @Override + public void setM(Map map, AMapToData m, int i) { + m.set(i, map.get(Integer.valueOf(getInt(i))).intValue() - 1); + } + + @Override + 
public void setM(Map map, int si, AMapToData m, int i) { + final Integer v = Integer.valueOf(getInt(i)); + m.set(i, map.get(v).intValue() - 1); + } + @Override public String toString() { StringBuilder sb = new StringBuilder(_size * 5 + 2); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashLongArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashLongArray.java index 3c802d3267c..8fd308951e4 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashLongArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashLongArray.java @@ -23,10 +23,12 @@ import java.io.DataOutput; import java.io.IOException; import java.util.Arrays; +import java.util.Map; import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData; import org.apache.sysds.runtime.frame.data.columns.ArrayFactory.FrameArrayType; import org.apache.sysds.runtime.matrix.data.Pair; import org.apache.sysds.runtime.util.UtilFunctions; @@ -103,11 +105,6 @@ public void set(int index, double value) { _data[index] = (long) value; } - @Override - public void set(int rl, int ru, Array value) { - set(rl, ru, value, 0); - } - @Override public void setFromOtherType(int rl, int ru, Array value) { for(int i = rl; i <= ru; i++) @@ -432,6 +429,26 @@ public boolean possiblyContainsNaN() { return false; } + @Override + protected int addValRecodeMap(Map map, int id, int i) { + Long val = Long.valueOf(getLong(i)); + Integer v = map.putIfAbsent(val, id); + if(v == null) + id++; + + return id; + } + + @Override + public void setM(Map map, AMapToData m, int i) { + m.set(i, map.get(Long.valueOf(getLong(i))) - 1); + } + + @Override + public void setM(Map map, int si, AMapToData m, int i) { + m.set(i, map.get(Long.valueOf(getLong(i))) - 1); + } + @Override public String toString() { 
StringBuilder sb = new StringBuilder(_size * 5 + 2); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/IntegerArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/IntegerArray.java index cb06512874c..7a698dbd72f 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/IntegerArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/IntegerArray.java @@ -68,12 +68,7 @@ public void set(int index, double value) { public void set(int index, String value) { set(index, parseInt(value)); } - - @Override - public void set(int rl, int ru, Array value) { - set(rl, ru, value, 0); - } - + @Override public void setFromOtherType(int rl, int ru, Array value) { final ValueType vt = value.getValueType(); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/LongArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/LongArray.java index 174007dc2b3..aad83fa9e7e 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/LongArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/LongArray.java @@ -36,7 +36,6 @@ public class LongArray extends Array { private long[] _data; - private LongArray(int nRow) { this(new long[nRow]); } @@ -70,11 +69,6 @@ public void set(int index, String value) { set(index, parseLong(value)); } - @Override - public void set(int rl, int ru, Array value) { - set(rl, ru, value, 0); - } - @Override public void setFromOtherType(int rl, int ru, Array value) { final ValueType vt = value.getValueType(); @@ -156,7 +150,6 @@ protected static LongArray read(DataInput in, int nRow) throws IOException { return arr; } - @Override public Array clone() { return new LongArray(Arrays.copyOf(_data, _size)); @@ -322,7 +315,7 @@ public static long parseLong(String s) { if(s == null || s.isEmpty()) return 0; try { - Long v = Long.parseLong(s); + Long v = Long.parseLong(s); return v; } catch(NumberFormatException e) { diff --git 
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/OptionalArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/OptionalArray.java index 366d00be886..dd0fca6cdfb 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/OptionalArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/OptionalArray.java @@ -22,11 +22,11 @@ import java.io.DataInput; import java.io.DataOutput; import java.io.IOException; -import java.util.HashMap; import java.util.Map; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData; import org.apache.sysds.runtime.frame.data.columns.ArrayFactory.FrameArrayType; import org.apache.sysds.runtime.matrix.data.Pair; import org.apache.sysds.runtime.util.UtilFunctions; @@ -211,11 +211,6 @@ public void setFromOtherType(int rl, int ru, Array value) { } } - @Override - public void set(int rl, int ru, Array value) { - set(rl, ru, value, 0); - } - @Override public void set(int rl, int ru, Array value, int rlSrc) { if(value instanceof OptionalArray) @@ -473,24 +468,23 @@ public boolean possiblyContainsNaN() { } @Override - protected Map createRecodeMap() { - if(getValueType() == ValueType.BOOLEAN) { - // shortcut for boolean arrays, since we only - // need to encounter the first two false and true values. 
- Map map = new HashMap<>(); - long id = 1; - for(int i = 0; i < size() && id <= 2; i++) { - T val = get(i); - if(val != null) { - Long v = map.putIfAbsent(val, id); - if(v == null) - id++; - } - } - return map; - } + public void setM(Map map, AMapToData m, int i) { + _a.setM(map, m, i); + } + + @Override + public void setM(Map map, int si, AMapToData m, int i) { + if(_n.get(i)) + _a.setM(map, si, m, i); else - return super.createRecodeMap(); + m.set(i, si); + } + + @Override + protected int addValRecodeMap(Map map, int id, int i) { + if(_n.get(i)) + id = _a.addValRecodeMap(map, id, i); + return id; } @Override diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/RaggedArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/RaggedArray.java index 04a1a2ee5eb..e0c823ca1b0 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/RaggedArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/RaggedArray.java @@ -149,10 +149,6 @@ public void setFromOtherType(int rl, int ru, Array value) { throw new NotImplementedException("Unimplemented method 'setFromOtherType'"); } - @Override - public void set(int rl, int ru, Array value) { - set(rl, ru, value, 0); - } @Override public void set(int rl, int ru, Array value, int rlSrc) { @@ -307,7 +303,7 @@ protected Array changeTypeCharacter(Array retA, int l, int return _a.changeTypeCharacter(retA, l, u); } - @Override + @Override public Array changeTypeWithNulls(ValueType t) { throw new NotImplementedException("Not Implemented ragged array with nulls"); } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java index 3fb0a4e1da2..f1ef7943498 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java @@ -25,6 +25,8 @@ import java.util.Arrays; import 
java.util.HashMap; import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.common.Types.ValueType; @@ -76,11 +78,6 @@ public void set(int index, double value) { materializedSize = -1; } - @Override - public void set(int rl, int ru, Array value) { - set(rl, ru, value, 0); - } - @Override public void setFromOtherType(int rl, int ru, Array value) { for(int i = rl; i <= ru; i++) { @@ -672,21 +669,20 @@ public final boolean isNotEmpty(int i) { } @Override - protected Map createRecodeMap() { + protected Map createRecodeMap(int estimate, ExecutorService pool) throws InterruptedException, ExecutionException { try { - - Map map = new HashMap<>(); + Map map = new HashMap<>((int) Math.min((long) estimate * 2, size())); for(int i = 0; i < size(); i++) { Object val = get(i); if(val != null) { String[] tmp = ColumnEncoderRecode.splitRecodeMapEntry(val.toString()); - map.put(tmp[0], Long.parseLong(tmp[1])); + map.put(tmp[0], Integer.parseInt(tmp[1])); } } return map; } catch(Exception e) { - return super.createRecodeMap(); + return super.createRecodeMap(estimate, pool); } } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/compress/CompressedFrameBlockFactory.java b/src/main/java/org/apache/sysds/runtime/frame/data/compress/CompressedFrameBlockFactory.java index 894cd1681a6..de7031c7c01 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/compress/CompressedFrameBlockFactory.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/compress/CompressedFrameBlockFactory.java @@ -140,6 +140,7 @@ private Array compressColFinally(int i, Future> f) throws Exception private Array allocateCorrectedType(int i) { final ArrayCompressionStatistics s = stats[i]; final Array a = in.getColumn(i); + if(s.valueType != a.getValueType()) return ArrayFactory.allocate(s.valueType, a.size(), s.containsNull); else @@ -226,11 +227,8 @@ 
private void logStatistics() { for(int i = 0; i < compressedColumns.length; i++) { if(in.getColumn(i) instanceof ACompressedArray) sb.append(String.format("Col: %3d, %s\n", i, "Column is already compressed")); - else if(stats[i].shouldCompress) - sb.append(String.format("Col: %3d, %s\n", i, stats[i])); else - sb.append(String.format("Col: %3d, No Compress, Type: %s", // - i, in.getColumn(i).getClass().getSimpleName())); + sb.append(String.format("Col: %3d, %s\n", i, stats[i])); } LOG.debug(sb); } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/MultiReturnParameterizedBuiltinSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/MultiReturnParameterizedBuiltinSPInstruction.java index df9fd84f779..b2796469ac0 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/spark/MultiReturnParameterizedBuiltinSPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/MultiReturnParameterizedBuiltinSPInstruction.java @@ -325,7 +325,7 @@ public Iterator call(Tuple2> arg0) throws Exce Iterator iter = arg0._2().iterator(); ArrayList ret = new ArrayList<>(); - long rowID = 1; + int rowID = 1; StringBuilder sb = new StringBuilder(); // handle recode maps @@ -371,7 +371,7 @@ else if(_encoder.containsEncoderForID(colID, ColumnEncoderBin.class)) { else { throw new DMLRuntimeException("Unsupported metadata output for encoder: \n" + _encoder); } - _accMax.add(rowID - 1); + _accMax.add(rowID - 1L); return ret.iterator(); } diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java index 18bb1043966..b5e4ae21d3e 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java @@ -1413,15 +1413,14 @@ else if( !sparse && denseBlock!=null ) //DENSE /** * Recompute the number of nonZero values in parallel * - * @param k the 
paralelization degree + * @param k the parallelization degree * @return the number of non zeros */ public long recomputeNonZeros(int k) { - if(sparse && sparseBlock!=null) + // fallback to single thread if k <= 1, small matrix, or sparse. + if(k <= 1 || ((long) rlen * clen < 10000) || (sparse && sparseBlock!=null)) return recomputeNonZeros(); else if(!sparse && denseBlock!=null){ - if((long) rlen * clen < 10000) - return recomputeNonZeros(); final ExecutorService pool = CommonThreadPool.get(k); try { List> f = new ArrayList<>(); @@ -1451,7 +1450,7 @@ else if(!sparse && denseBlock!=null){ } catch(Exception e) { - LOG.warn("Failed Parallel non zero count fallback to singlethread"); + LOG.warn("Failed Parallel non zero count fallback to single thread"); return recomputeNonZeros(); } finally { diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java index 019df7f8470..037e7bea1d7 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java @@ -29,10 +29,8 @@ import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.concurrent.Callable; import org.apache.commons.logging.Log; @@ -41,8 +39,8 @@ import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.caching.CacheBlock; import org.apache.sysds.runtime.data.DenseBlock; -import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseBlockCSR; +import org.apache.sysds.runtime.data.SparseBlockMCSR; import org.apache.sysds.runtime.data.SparseRowVector; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -62,6 +60,9 @@ public abstract class 
ColumnEncoder implements Encoder, Comparable _sparseRowsWZeros = null; protected int[] sparseRowPointerOffset = null; // offsets created by bag of words encoders (multiple nnz) + // protected ArrayList _sparseRowsWZeros = null; + + protected boolean containsZeroOut = false; protected long _estMetaSize = 0; protected int _estNumDistincts = 0; protected int _nBuildPartitions = 0; @@ -147,8 +148,7 @@ public MatrixBlock apply(CacheBlock in, MatrixBlock out, int outputCol, int r protected abstract double[] getCodeCol(CacheBlock in, int startInd, int rowEnd, double[] tmp); protected void applySparse(CacheBlock in, MatrixBlock out, int outputCol, int rowStart, int blk){ - boolean mcsr = MatrixBlock.DEFAULT_SPARSEBLOCK == SparseBlock.Type.MCSR; - mcsr = false; //force CSR for transformencode + boolean mcsr = out.getSparseBlock() instanceof SparseBlockMCSR; int index = _colID - 1; // Apply loop tiling to exploit CPU caches int rowEnd = getEndIndex(in.getNumRows(), rowStart, blk); @@ -425,20 +425,17 @@ public List> getApplyTasks(CacheBlock in, MatrixBlock out, return new ColumnApplyTask<>(this, in, out, outputCol, startRow, blk); } - public Set getSparseRowsWZeros(){ - if (_sparseRowsWZeros != null) { - return new HashSet<>(_sparseRowsWZeros); - } - else - return null; - } - protected void addSparseRowsWZeros(List sparseRowsWZeros){ synchronized (this){ if(_sparseRowsWZeros == null) _sparseRowsWZeros = new ArrayList<>(); _sparseRowsWZeros.addAll(sparseRowsWZeros); } + + } + + protected boolean containsZeroOut(){ + return containsZeroOut; } protected void setBuildRowBlocksPerColumn(int nPart) { diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBagOfWords.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBagOfWords.java index 25b1a0ce876..badd9e200fb 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBagOfWords.java +++ 
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBagOfWords.java @@ -48,7 +48,7 @@ public class ColumnEncoderBagOfWords extends ColumnEncoder { public static int NUM_SAMPLES_MAP_ESTIMATION = 16000; - private Map _tokenDictionary; // switched from int to long to reuse code from RecodeEncoder + private Map _tokenDictionary; // switched from int to long to reuse code from RecodeEncoder private HashSet _tokenDictionaryPart = null; protected String _seperatorRegex = "\\s+"; // whitespace protected boolean _caseSensitive = false; @@ -74,11 +74,11 @@ public ColumnEncoderBagOfWords(ColumnEncoderBagOfWords enc) { _caseSensitive = enc._caseSensitive; } - public void setTokenDictionary(HashMap dict){ + public void setTokenDictionary(HashMap dict){ _tokenDictionary = dict; } - public Map getTokenDictionary() { + public Map getTokenDictionary() { return _tokenDictionary; } @@ -218,7 +218,7 @@ public void build(CacheBlock in) { if(!token.isEmpty()){ tokenSetPerRow.add(token); if(!_tokenDictionary.containsKey(token)) - _tokenDictionary.put(token, (long) i++); + _tokenDictionary.put(token, i++); } _nnzPerRow[r] = tokenSetPerRow.size(); _nnz += tokenSetPerRow.size(); @@ -297,7 +297,7 @@ protected void applySparse(CacheBlock in, MatrixBlock out, int outputCol, int int i = 0; for (Map.Entry entry : counter.entrySet()) { String token = entry.getKey(); - columnValuePairs[i] = new Pair((int) (outputCol + _tokenDictionary.getOrDefault(token, 0L) - 1), entry.getValue()); + columnValuePairs[i] = new Pair((int) (outputCol + _tokenDictionary.getOrDefault(token, 0) - 1), entry.getValue()); // if token is not included columnValuePairs[i] is overwritten in the next iteration i += _tokenDictionary.containsKey(token) ? 
1 : 0; } @@ -363,7 +363,7 @@ public void allocateMetaData(FrameBlock meta) { public FrameBlock getMetaData(FrameBlock out) { int rowID = 0; StringBuilder sb = new StringBuilder(); - for(Map.Entry e : _tokenDictionary.entrySet()) { + for(Map.Entry e : _tokenDictionary.entrySet()) { out.set(rowID++, _colID - 1, constructRecodeMapEntry(e.getKey(), e.getValue(), sb)); } return out; @@ -382,9 +382,9 @@ public void writeExternal(ObjectOutput out) throws IOException { out.writeInt(_tokenDictionary == null ? 0 : _tokenDictionary.size()); if(_tokenDictionary != null) - for(Map.Entry e : _tokenDictionary.entrySet()) { + for(Map.Entry e : _tokenDictionary.entrySet()) { out.writeUTF((String) e.getKey()); - out.writeLong(e.getValue()); + out.writeInt(e.getValue()); } } @@ -395,7 +395,7 @@ public void readExternal(ObjectInput in) throws IOException { _tokenDictionary = new HashMap<>(size * 4 / 3); for(int j = 0; j < size; j++) { String key = in.readUTF(); - Long value = in.readLong(); + Integer value = in.readInt(); _tokenDictionary.put(key, value); } } @@ -476,11 +476,11 @@ private BowMergePartialBuildTask(ColumnEncoderBagOfWords encoderRecode, HashMap< @Override public Object call() { long t0 = DMLScript.STATISTICS ? 
System.nanoTime() : 0; - Map tokenDictionary = _encoder._tokenDictionary; + Map tokenDictionary = _encoder._tokenDictionary; for(Object tokenSet : _partialMaps.values()){ ( (HashSet) tokenSet).forEach(token -> { if(!tokenDictionary.containsKey(token)) - tokenDictionary.put(token, (long) tokenDictionary.size() + 1); + tokenDictionary.put(token, tokenDictionary.size() + 1); }); } for (long nnzPartial : _encoder._nnzPartials) diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java index 74b4737194c..524a745a467 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java @@ -19,6 +19,8 @@ package org.apache.sysds.runtime.transform.encode; +import static org.apache.sysds.runtime.util.UtilFunctions.getEndIndex; + import java.io.IOException; import java.io.ObjectInput; import java.io.ObjectOutput; @@ -28,7 +30,6 @@ import java.util.Random; import java.util.concurrent.Callable; -import static org.apache.sysds.runtime.util.UtilFunctions.getEndIndex; import org.apache.commons.lang3.tuple.MutableTriple; import org.apache.sysds.api.DMLScript; import org.apache.sysds.lops.Lop; @@ -36,7 +37,6 @@ import org.apache.sysds.runtime.controlprogram.caching.CacheBlock; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.frame.data.columns.Array; -import org.apache.sysds.runtime.frame.data.columns.StringArray; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.utils.stats.TransformStatistics; @@ -60,6 +60,10 @@ public class ColumnEncoderBin extends ColumnEncoder { private double _colMins = -1f; private double _colMaxs = -1f; + protected boolean containsNull = false; + + protected boolean checkedForNull = false; + public ColumnEncoderBin() { super(-1); } @@ -131,6 +135,15 @@ else if(_binMethod == 
BinMethod.EQUI_HEIGHT_APPROX){ computeEqualHeightBins(vals, false); } + if(in instanceof FrameBlock){ + final Array c = ((FrameBlock )in).getColumn(_colID - 1); + containsNull = c.containsNull(); + checkedForNull = true; + } + else { + checkedForNull = true; + } + if(DMLScript.STATISTICS) TransformStatistics.incBinningBuildTime(System.nanoTime()-t0); } @@ -188,7 +201,7 @@ protected final void getCodeColFrame(FrameBlock in, int startInd, int endInd, do final Array c = in.getColumn(_colID - 1); final double mi = _binMins[0]; final double mx = _binMaxs[_binMaxs.length-1]; - if(!(c instanceof StringArray) && !c.containsNull()) + if(!containsNull && checkedForNull) for(int i = startInd; i < endInd; i++) codes[i - startInd] = getCodeIndex(c.getAsDouble(i), mi, mx); else @@ -209,15 +222,24 @@ else if(_binMethod == BinMethod.EQUI_WIDTH) return getCodeIndexEQHeight(inVal); } - private final double getEqWidth(double inVal, double min, double max) { + protected final double getEqWidth(double inVal, double min, double max) { if(max == min) return 1; - if(_numBin <= 0) - throw new RuntimeException("Invalid num bins"); - final int code = (int)(Math.ceil((inVal - min) / (max - min) * _numBin) ); + return getEqWidthUnsafe(inVal, min, max); + } + + protected final int getEqWidthUnsafe(double inVal){ + final double min = _binMins[0]; + final double max = _binMaxs[_binMaxs.length - 1]; + return getEqWidthUnsafe(inVal, min, max); + } + + protected final int getEqWidthUnsafe(double inVal, double min, double max){ + final int code = (int)(Math.ceil((inVal - min) / (max - min) * _numBin)); return code > _numBin ? _numBin : code < 1 ? 
1 : code; } + private final double getCodeIndexEQHeight(double inVal){ if(_binMaxs.length <= 10) return getCodeIndexEQHeightSmall(inVal); @@ -253,9 +275,17 @@ protected TransformType getTransformType() { private static double[] getMinMaxOfCol(CacheBlock in, int colID, int startRow, int blockSize) { // derive bin boundaries from min/max per column + final int end = getEndIndex(in.getNumRows(), startRow, blockSize); + if(in instanceof FrameBlock){ + FrameBlock inf = (FrameBlock) in; + if(startRow == 0 && blockSize == -1) + return inf.getColumn(colID -1).minMax(); + else + return inf.getColumn(colID - 1).minMax(startRow, end); + } + double min = Double.POSITIVE_INFINITY; double max = Double.NEGATIVE_INFINITY; - final int end = getEndIndex(in.getNumRows(), startRow, blockSize); for(int i = startRow; i < end; i++) { final double inVal = in.getDoubleNaN(i, colID - 1); if(!Double.isNaN(inVal)){ @@ -274,17 +304,12 @@ private static double[] prepareDataForEqualHeightBins(CacheBlock in, int colI private static double[] extractDoubleColumn(CacheBlock in, int colID, int startRow, int blockSize) { int endRow = getEndIndex(in.getNumRows(), startRow, blockSize); - double[] vals = new double[endRow - startRow]; final int cid = colID -1; + double[] vals = new double[endRow - startRow]; if(in instanceof FrameBlock) { // FrameBlock optimization Array a = ((FrameBlock) in).getColumn(cid); - for(int i = startRow; i < endRow; i++) { - double inVal = a.getAsNaNDouble(i); - if(Double.isNaN(inVal)) - continue; - vals[i - startRow] = inVal; - } + return a.extractDouble(vals, startRow, endRow); } else { for(int i = startRow; i < endRow; i++) { diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java index 536b387a1da..ce3008802a3 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java +++ 
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java @@ -27,10 +27,8 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.Objects; import java.util.concurrent.Callable; -import java.util.stream.Collectors; import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.runtime.DMLRuntimeException; @@ -425,13 +423,11 @@ public void shiftCol(int columnOffset) { _columnEncoders.forEach(e -> e.shiftCol(columnOffset)); } - @Override - public Set getSparseRowsWZeros(){ - return _columnEncoders.stream().map(ColumnEncoder::getSparseRowsWZeros).flatMap(l -> { - if(l == null) - return null; - return l.stream(); - }).collect(Collectors.toSet()); + protected boolean containsZeroOut(){ + for(int i = 0; i < _columnEncoders.size(); i++) + if(_columnEncoders.get(i).containsZeroOut()) + return true; + return false; } @Override diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java index fd6e3410bf1..616a6a7ce8b 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java @@ -24,15 +24,14 @@ import java.io.IOException; import java.io.ObjectInput; import java.io.ObjectOutput; -import java.util.ArrayList; import java.util.List; import java.util.Objects; import org.apache.sysds.api.DMLScript; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.caching.CacheBlock; -import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseBlockCSR; +import org.apache.sysds.runtime.data.SparseBlockMCSR; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.util.DependencyTask; 
@@ -115,18 +114,15 @@ protected void applySparse(CacheBlock in, MatrixBlock out, int outputCol, int throw new DMLRuntimeException( "ColumnEncoderDummycode called with: " + in.getClass().getSimpleName() + " and not MatrixBlock"); } - boolean mcsr = MatrixBlock.DEFAULT_SPARSEBLOCK == SparseBlock.Type.MCSR; - mcsr = false; // force CSR for transformencode - ArrayList sparseRowsWZeros = null; + boolean mcsr = out.getSparseBlock() instanceof SparseBlockMCSR; + // ArrayList sparseRowsWZeros = null; int index = _colID - 1; for(int r = rowStart; r < getEndIndex(in.getNumRows(), rowStart, blk); r++) { int indexWithOffset = sparseRowPointerOffset != null ? sparseRowPointerOffset[r] - 1 + index : index; if(mcsr) { double val = out.getSparseBlock().get(r).values()[indexWithOffset]; if(Double.isNaN(val)) { - if(sparseRowsWZeros == null) - sparseRowsWZeros = new ArrayList<>(); - sparseRowsWZeros.add(r); + containsZeroOut = true; out.getSparseBlock().get(r).values()[indexWithOffset] = 0; continue; } @@ -139,9 +135,7 @@ protected void applySparse(CacheBlock in, MatrixBlock out, int outputCol, int int rptr[] = csrblock.rowPointers(); double val = csrblock.values()[rptr[r] + indexWithOffset]; if(Double.isNaN(val)) { - if(sparseRowsWZeros == null) - sparseRowsWZeros = new ArrayList<>(); - sparseRowsWZeros.add(r); + containsZeroOut = true; csrblock.values()[rptr[r] + indexWithOffset] = 0; // test continue; } @@ -151,9 +145,6 @@ protected void applySparse(CacheBlock in, MatrixBlock out, int outputCol, int csrblock.values()[rptr[r] + indexWithOffset] = 1; } } - if(sparseRowsWZeros != null) { - addSparseRowsWZeros(sparseRowsWZeros); - } } protected void applyDense(CacheBlock in, MatrixBlock out, int outputCol, int rowStart, int blk){ diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java index 00c65097567..400b7f64ffc 100644 --- 
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java @@ -25,6 +25,7 @@ import java.util.List; import org.apache.sysds.api.DMLScript; +import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.caching.CacheBlock; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.frame.data.columns.Array; @@ -67,7 +68,7 @@ protected TransformType getTransformType() { @Override protected double getCode(CacheBlock in, int row) { if(in instanceof FrameBlock){ - Array a = ((FrameBlock)in).getColumn(_colID -1); + Array a = ((FrameBlock)in).getColumn(_colID - 1); return getCode(a, row); } else{ // default @@ -80,16 +81,24 @@ protected double getCode(CacheBlock in, int row) { } protected double getCode(Array a, int row){ - return Math.abs(a.hashDouble(row) % _K + 1); + return Math.abs(a.hashDouble(row)) % _K + 1; + } + + protected static double getCode(Array a, int k , int row){ + return Math.abs(a.hashDouble(row)) % k ; } protected double[] getCodeCol(CacheBlock in, int startInd, int endInd, double[] tmp) { final int endLength = endInd - startInd; final double[] codes = tmp != null && tmp.length == endLength ? 
tmp : new double[endLength]; - if( in instanceof FrameBlock) { + if(in instanceof FrameBlock) { Array a = ((FrameBlock) in).getColumn(_colID-1); - for(int i = startInd; i < endInd; i++) - codes[i - startInd] = getCode(a, i); + for(int i = startInd; i < endInd; i++){ + double code = getCode(a, i); + if(code <= 0) + throw new DMLRuntimeException("Bad Code"); + codes[i - startInd] = code; + } } else {// default for(int i = startInd; i < endInd; i++) diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java index 411e650aa4f..bee71d99eee 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java @@ -21,13 +21,13 @@ import static org.apache.sysds.runtime.util.UtilFunctions.getEndIndex; -import java.util.ArrayList; import java.util.List; import org.apache.sysds.api.DMLScript; import org.apache.sysds.runtime.controlprogram.caching.CacheBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseBlockCSR; +import org.apache.sysds.runtime.data.SparseBlockMCSR; import org.apache.sysds.runtime.data.SparseRowVector; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -82,41 +82,48 @@ protected double[] getCodeCol(CacheBlock in, int startInd, int endInd, double } protected void applySparse(CacheBlock in, MatrixBlock out, int outputCol, int rowStart, int blk){ - //Set sparseRowsWZeros = null; - ArrayList sparseRowsWZeros = null; - boolean mcsr = MatrixBlock.DEFAULT_SPARSEBLOCK == SparseBlock.Type.MCSR; - mcsr = false; //force CSR for transformencode - int index = _colID - 1; - // Apply loop tiling to exploit CPU caches - int rowEnd = getEndIndex(in.getNumRows(), rowStart, blk); - double[] codes = getCodeCol(in, rowStart, 
rowEnd, null); - int B = 32; //tile size - for(int i = rowStart; i < rowEnd; i+=B) { - int lim = Math.min(i+B, rowEnd); - for (int ii=i; ii(); - sparseRowsWZeros.add(ii); - } - int indexWithOffset = sparseRowPointerOffset != null ? sparseRowPointerOffset[ii] - 1 + index : index; - if (mcsr) { - SparseRowVector row = (SparseRowVector) out.getSparseBlock().get(ii); - row.values()[indexWithOffset] = v; - row.indexes()[indexWithOffset] = outputCol; - } - else { //csr - // Manually fill the column-indexes and values array - SparseBlockCSR csrblock = (SparseBlockCSR)out.getSparseBlock(); - int rptr[] = csrblock.rowPointers(); - csrblock.indexes()[rptr[ii]+indexWithOffset] = outputCol; - csrblock.values()[rptr[ii]+indexWithOffset] = codes[ii-rowStart]; - } - } + final SparseBlock sb = out.getSparseBlock(); + final boolean mcsr = sb instanceof SparseBlockMCSR; + final int index = _colID - 1; + final int rowEnd = getEndIndex(in.getNumRows(), rowStart, blk); + final int bs = 32; + double[] tmp = null; + for(int i = rowStart; i < rowEnd; i+= bs) { + int end = Math.min(i + bs , rowEnd); + tmp = getCodeCol(in, i, end,tmp); + if(mcsr) + applySparseBlockMCSR(in, (SparseBlockMCSR) sb, index, outputCol, i, end, tmp); + else + applySparseBlockCSR(in, (SparseBlockCSR) sb, index, outputCol, i, end, tmp); + + } + } + + private void applySparseBlockMCSR(CacheBlock in, final SparseBlockMCSR sb, final int index, + final int outputCol, int rl, int ru, double[] tmpCodes) { + for(int i = rl; i < ru; i ++) { + final double v = tmpCodes[i - rl]; + SparseRowVector row = (SparseRowVector) sb.get(i); + row.indexes()[index] = outputCol; + if(v == 0) + containsZeroOut = true; + else + row.values()[index] = v; } - if(sparseRowsWZeros != null){ - addSparseRowsWZeros(sparseRowsWZeros); + } + + private void applySparseBlockCSR(CacheBlock in, final SparseBlockCSR sb, final int index, final int outputCol, + int rl, int ru, double[] tmpCodes) { + final int[] rptr = sb.rowPointers(); + final int[] idx = 
sb.indexes(); + final double[] val = sb.values(); + for(int i = rl; i < ru; i++) { + final double v = tmpCodes[i - rl]; + idx[rptr[i] + index] = outputCol; + if(v == 0) + containsZeroOut = true; + else + val[rptr[i] + index] = v; } } diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java index 059c1f94589..e784086427d 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java @@ -47,7 +47,7 @@ public class ColumnEncoderRecode extends ColumnEncoder { public static boolean SORT_RECODE_MAP = false; // recode maps and custom map for partial recode maps - private Map _rcdMap; + private Map _rcdMap; private HashSet _rcdMapPart = null; public ColumnEncoderRecode(int colID) { @@ -59,7 +59,7 @@ public ColumnEncoderRecode() { this(-1); } - protected ColumnEncoderRecode(int colID, HashMap rcdMap) { + protected ColumnEncoderRecode(int colID, HashMap rcdMap) { super(colID); _rcdMap = rcdMap; } @@ -71,12 +71,12 @@ protected ColumnEncoderRecode(int colID, HashMap rcdMap) { * @param code is code for token * @return the concatenation of token and code with delimiter in between */ - public static String constructRecodeMapEntry(String token, Long code) { + public static String constructRecodeMapEntry(String token, Integer code) { StringBuilder sb = new StringBuilder(token.length() + 16); return constructRecodeMapEntry(token, code, sb); } - public static String constructRecodeMapEntry(Object token, Long code, StringBuilder sb) { + public static String constructRecodeMapEntry(Object token, Integer code, StringBuilder sb) { sb.setLength(0); // reset reused string builder return sb.append(token).append(Lop.DATATYPE_PREFIX).append(code.longValue()).toString(); } @@ -94,7 +94,7 @@ public static String[] splitRecodeMapEntry(String value) { return new 
String[] {value.substring(0, pos), value.substring(pos + 1)}; } - public Map getCPRecodeMaps() { + public Map getCPRecodeMaps() { return _rcdMap; } @@ -106,7 +106,7 @@ public void sortCPRecodeMaps() { sortCPRecodeMaps(_rcdMap); } - private static void sortCPRecodeMaps(Map map) { + private static void sortCPRecodeMaps(Map map) { Object[] keys = map.keySet().toArray(new Object[0]); Arrays.sort(keys); map.clear(); @@ -114,7 +114,7 @@ private static void sortCPRecodeMaps(Map map) { putCode(map, key); } - private static void makeRcdMap(CacheBlock in, Map map, int colID, int startRow, int blk) { + private static void makeRcdMap(CacheBlock in, Map map, int colID, int startRow, int blk) { for(int row = startRow; row < getEndIndex(in.getNumRows(), startRow, blk); row++){ String key = in.getString(row, colID - 1); if(key != null && !key.isEmpty() && !map.containsKey(key)) @@ -126,7 +126,7 @@ private static void makeRcdMap(CacheBlock in, Map map, int colI } private long lookupRCDMap(Object key) { - return _rcdMap.getOrDefault(key, -1L); + return _rcdMap.getOrDefault(key, -1); } public void computeMapSizeEstimate(CacheBlock in, int[] sampleIndices) { @@ -203,8 +203,8 @@ public Callable getPartialMergeBuildTask(HashMap ret) { * @param map column map * @param key key for the new entry */ - protected static void putCode(Map map, Object key) { - map.put(key, (long) (map.size() + 1)); + protected static void putCode(Map map, Object key) { + map.put(key, (map.size() + 1)); } protected double getCode(CacheBlock in, int r){ @@ -270,10 +270,10 @@ public void mergeAt(ColumnEncoder other) { assert other._colID == _colID; // merge together overlapping columns ColumnEncoderRecode otherRec = (ColumnEncoderRecode) other; - Map otherMap = otherRec._rcdMap; + Map otherMap = otherRec._rcdMap; if(otherMap != null) { // for each column, add all non present recode values - for(Map.Entry entry : otherMap.entrySet()) { + for(Map.Entry entry : otherMap.entrySet()) { if(lookupRCDMap(entry.getKey()) == 
-1) { // key does not yet exist putCode(_rcdMap, entry.getKey()); @@ -305,7 +305,7 @@ public FrameBlock getMetaData(FrameBlock meta) { // create compact meta data representation StringBuilder sb = new StringBuilder(); // for reuse int rowID = 0; - for(Entry e : _rcdMap.entrySet()) { + for(Entry e : _rcdMap.entrySet()) { meta.set(rowID++, _colID - 1, // 1-based constructRecodeMapEntry(e.getKey(), e.getValue(), sb)); } @@ -331,9 +331,9 @@ public void writeExternal(ObjectOutput out) throws IOException { super.writeExternal(out); out.writeInt(_rcdMap.size()); - for(Entry e : _rcdMap.entrySet()) { + for(Entry e : _rcdMap.entrySet()) { out.writeUTF(e.getKey().toString()); - out.writeLong(e.getValue()); + out.writeInt(e.getValue()); } } @@ -343,7 +343,7 @@ public void readExternal(ObjectInput in) throws IOException { int size = in.readInt(); for(int j = 0; j < size; j++) { String key = in.readUTF(); - Long value = in.readLong(); + Integer value = in.readInt(); _rcdMap.put(key, value); } } @@ -363,7 +363,7 @@ public int hashCode() { return Objects.hash(_rcdMap); } - public Map getRcdMap() { + public Map getRcdMap() { return _rcdMap; } @@ -374,7 +374,12 @@ public String toString() { sb.append(": "); sb.append(_colID); sb.append(" --- map: "); - sb.append(_rcdMap); + if(_rcdMap.size() < 1000){ + sb.append(_rcdMap); + } + else{ + sb.append("Map too big to print but size is : " + _rcdMap.size()); + } return sb.toString(); } @@ -425,7 +430,7 @@ protected RecodePartialBuildTask(CacheBlock input, int colID, int startRow, @Override public Object call() throws Exception { long t0 = DMLScript.STATISTICS ?
System.nanoTime() : 0; - HashMap partialMap = new HashMap<>(); + HashMap partialMap = new HashMap<>(); makeRcdMap(_input, partialMap, _colID, _startRow, _blockSize); synchronized(_partialMaps) { _partialMaps.put(_startRow, partialMap); @@ -455,7 +460,7 @@ private RecodeMergePartialBuildTask(ColumnEncoderRecode encoderRecode, HashMap rcdMap = _encoder.getRcdMap(); + Map rcdMap = _encoder.getRcdMap(); _partialMaps.forEach((start_row, map) -> { ((HashMap) map).forEach((k, v) -> { if(!rcdMap.containsKey(k)) diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderWordEmbedding.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderWordEmbedding.java index 76f1c12a7d3..a4a3fa862bd 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderWordEmbedding.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderWordEmbedding.java @@ -35,7 +35,7 @@ public class ColumnEncoderWordEmbedding extends ColumnEncoder { private MatrixBlock _wordEmbeddings; - private Map _rcdMap; + private Map _rcdMap; private HashMap _embMap; public ColumnEncoderWordEmbedding() { @@ -45,8 +45,8 @@ public ColumnEncoderWordEmbedding() { } @SuppressWarnings("unused") - private long lookupRCDMap(Object key) { - return _rcdMap.getOrDefault(key, -1L); + private Integer lookupRCDMap(Object key) { + return _rcdMap.getOrDefault(key, -1); } //domain size is equal to the number columns of the embeddings column thats equal to length of an embedding vector @@ -58,6 +58,7 @@ public int getDomainSize(){ public int getNrDistinctEmbeddings(){ return _wordEmbeddings.getNumRows(); } + protected ColumnEncoderWordEmbedding(int colID) { super(colID); } @@ -138,9 +139,9 @@ public void writeExternal(ObjectOutput out) throws IOException { super.writeExternal(out); out.writeInt(_rcdMap.size()); - for(Map.Entry e : _rcdMap.entrySet()) { + for(Map.Entry e : _rcdMap.entrySet()) { out.writeUTF(e.getKey().toString()); - 
out.writeLong(e.getValue()); + out.writeInt(e.getValue()); } _wordEmbeddings.write(out); } @@ -151,7 +152,7 @@ public void readExternal(ObjectInput in) throws IOException { int size = in.readInt(); for(int j = 0; j < size; j++) { String key = in.readUTF(); - Long value = in.readLong(); + Integer value = in.readInt(); _rcdMap.put(key, value); } _wordEmbeddings.readExternal(in); diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java index 6506c6f9f43..7b4698d5bfe 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java @@ -23,7 +23,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.concurrent.Callable; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; @@ -32,8 +31,6 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.common.Types.ValueType; -import org.apache.sysds.conf.ConfigurationManager; -import org.apache.sysds.conf.DMLConfig; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.compress.colgroup.AColGroup; import org.apache.sysds.runtime.compress.colgroup.ColGroupConst; @@ -48,17 +45,25 @@ import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData; import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory; +import org.apache.sysds.runtime.compress.estim.sample.SampleEstimatorFactory; +import org.apache.sysds.runtime.compress.utils.Util; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.frame.data.columns.ACompressedArray; import org.apache.sysds.runtime.frame.data.columns.Array; 
import org.apache.sysds.runtime.frame.data.columns.DDCArray; +import org.apache.sysds.runtime.frame.data.compress.ArrayCompressionStatistics; import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.transform.encode.ColumnEncoderBin.BinMethod; import org.apache.sysds.runtime.util.CommonThreadPool; import org.apache.sysds.runtime.util.UtilFunctions; +import org.apache.sysds.utils.stats.Timing; public class CompressedEncode { protected static final Log LOG = LogFactory.getLog(CompressedEncode.class.getName()); + /** Row parallelization threshold for parallel creation of AMapToData for compression */ + public static int ROW_PARALLELIZATION_THRESHOLD = 10000; + /** The encoding scheme plan */ private final MultiColumnEncoder enc; /** The Input FrameBlock */ @@ -66,56 +71,65 @@ public class CompressedEncode { /** The thread count of the instruction */ private final int k; + /** the Executor pool for parallel tasks of this encoder. */ + private final ExecutorService pool; + + private final boolean inputContainsCompressed; + private CompressedEncode(MultiColumnEncoder enc, FrameBlock in, int k) { this.enc = enc; this.in = in; this.k = k; + this.pool = k > 1 && CommonThreadPool.useParallelismOnThread() ? CommonThreadPool.get(k) : null; + this.inputContainsCompressed = containsCompressed(in); } - public static MatrixBlock encode(MultiColumnEncoder enc, FrameBlock in, int k) - throws InterruptedException, ExecutionException { + private boolean containsCompressed(FrameBlock in) { + for(Array c : in.getColumns()) + if(c instanceof ACompressedArray) + return true; + return false; + } + + public static MatrixBlock encode(MultiColumnEncoder enc, FrameBlock in, int k) throws Exception { return new CompressedEncode(enc, in, k).apply(); } - private MatrixBlock apply() throws InterruptedException, ExecutionException { - final List encoders = enc.getColumnEncoders(); - final List groups = isParallel() ? 
multiThread(encoders) : singleThread(encoders); - final int cols = shiftGroups(groups); - final MatrixBlock mb = new CompressedMatrixBlock(in.getNumRows(), cols, -1, false, groups); - mb.recomputeNonZeros(); - logging(mb); - return mb; + private MatrixBlock apply() throws Exception { + try { + final List encoders = enc.getColumnEncoders(); + final List groups = isParallel() ? multiThread(encoders) : singleThread(encoders); + final int cols = shiftGroups(groups); + final MatrixBlock mb = new CompressedMatrixBlock(in.getNumRows(), cols, -1, false, groups); + mb.recomputeNonZeros(k); + logging(mb); + return mb; + } + finally { + if(pool != null) + pool.shutdown(); + } } private boolean isParallel() { - return k > 1 && enc.getEncoders().size() > 1; + return pool != null; } - private List singleThread(List encoders) { + private List singleThread(List encoders) throws Exception { List groups = new ArrayList<>(encoders.size()); for(ColumnEncoderComposite c : encoders) groups.add(encode(c)); return groups; } - private List multiThread(List encoders) - throws InterruptedException, ExecutionException { - - final ExecutorService pool = CommonThreadPool.get(k); - try { - List tasks = new ArrayList<>(encoders.size()); - - for(ColumnEncoderComposite c : encoders) - tasks.add(new EncodeTask(c)); - - List groups = new ArrayList<>(encoders.size()); - for(Future t : pool.invokeAll(tasks)) - groups.add(t.get()); - return groups; - } - finally { - pool.shutdown(); - } + private List multiThread(List encoders) throws Exception { + final List> tasks = new ArrayList<>(encoders.size()); + for(ColumnEncoderComposite c : encoders) + tasks.add(pool.submit(() -> encode(c))); + final List groups = new ArrayList<>(encoders.size()); + for(Future t : tasks) + groups.add(t.get()); + return groups; } /** @@ -133,7 +147,16 @@ private int shiftGroups(List groups) { return cols; } - private AColGroup encode(ColumnEncoderComposite c) { + private AColGroup encode(ColumnEncoderComposite c) throws 
Exception { + final Timing t = new Timing(); + AColGroup g = executeEncode(c); + if(LOG.isDebugEnabled()) + LOG.debug(String.format("Encode: columns: %4d estimateDistinct: %6d distinct: %6d size: %6d time: %10f", + c._colID, c._estNumDistincts, g.getNumValues(), g.estimateInMemorySize(), t.stop())); + return g; + } + + private AColGroup executeEncode(ColumnEncoderComposite c) throws Exception { if(c.isRecodeToDummy()) return recodeToDummy(c); else if(c.isRecode()) @@ -153,13 +176,15 @@ else if(c.isHashToDummy()) } @SuppressWarnings("unchecked") - private AColGroup recodeToDummy(ColumnEncoderComposite c) { + private AColGroup recodeToDummy(ColumnEncoderComposite c) throws Exception { int colId = c._colID; - Array a = in.getColumn(colId - 1); + Array a = (Array) in.getColumn(colId - 1); boolean containsNull = a.containsNull(); - Map map = a.getRecodeMap(); + estimateRCDMapSize(c); + Map map = a.getRecodeMap(c._estNumDistincts, CommonThreadPool.get(k)); + List r = c.getEncoders(); - r.set(0, new ColumnEncoderRecode(colId, (HashMap) map)); + r.set(0, new ColumnEncoderRecode(colId, (HashMap) map)); int domain = map.size(); if(containsNull && domain == 0) return new ColGroupEmpty(ColIndexFactory.create(1)); @@ -169,99 +194,149 @@ private AColGroup recodeToDummy(ColumnEncoderComposite c) { ADictionary d = new IdentityDictionary(colIndexes.size(), containsNull); AMapToData m = createMappingAMapToData(a, map, containsNull); + return ColGroupDDC.create(colIndexes, d, m, null); } - private AColGroup bin(ColumnEncoderComposite c) { + private AColGroup bin(ColumnEncoderComposite c) throws InterruptedException, ExecutionException { final int colId = c._colID; final Array a = in.getColumn(colId - 1); - final boolean containsNull = a.containsNull(); final List r = c.getEncoders(); final ColumnEncoderBin b = (ColumnEncoderBin) r.get(0); b.build(in); + final boolean containsNull = b.containsNull; final IColIndex colIndexes = ColIndexFactory.create(1); ADictionary d = 
createIncrementingVector(b._numBin, containsNull); - AMapToData m = binEncode(a, b, containsNull); + final AMapToData m; + m = binEncode(a, b, containsNull); AColGroup ret = ColGroupDDC.create(colIndexes, d, m, null); return ret; } - private AMapToData binEncode(Array a, ColumnEncoderBin b, boolean containsNull) { - AMapToData m = MapToFactory.create(a.size(), b._numBin + (containsNull ? 1 : 0)); - if(containsNull) { - for(int i = 0; i < a.size(); i++) { - final double v = a.getAsNaNDouble(i); - try { - - if(Double.isNaN(v)) - m.set(i, b._numBin); - else { - int idx = (int) b.getCodeIndex(v) - 1; - if(idx < 0) - idx = 0; - m.set(i, idx); - } - } - catch(Exception e) { - - m.set(i, (int) b.getCodeIndex(v - 0.00001) - 1); - } + private AMapToData binEncode(Array a, ColumnEncoderBin b, boolean nulls) + throws InterruptedException, ExecutionException { + final AMapToData m = MapToFactory.create(a.size(), b._numBin + (nulls ? 1 : 0)); + + if(!nulls && b.getBinMethod() == BinMethod.EQUI_WIDTH) { + final double min = b.getBinMins()[0]; + final double max = b.getBinMaxs()[b.getNumBin() - 1]; + if(Util.eq(max, min)) { + m.fill(0); + return m; } + if(b._numBin <= 0) + throw new RuntimeException("Invalid num bins"); + } + + final int rlen = a.size(); + if(k > 1 && rlen > ROW_PARALLELIZATION_THRESHOLD) { + BinEncodeParallel(a, b, nulls, m, rlen); } else { + if(nulls) + binEncodeWithNulls(a, b, m, 0, a.size()); + else + binEncodeNoNull(a, b, m, 0, a.size()); - for(int i = 0; i < a.size(); i++) { - try { - - int idx = (int) b.getCodeIndex(a.getAsDouble(i)) - 1; - if(idx < 0) - idx = 0; - // throw new RuntimeException(a.getAsDouble(i) + " is invalid value for " + b + "\n" + idx); - m.set(i, idx); - } - catch(Exception e) { - - int idx = (int) b.getCodeIndex(a.getAsDouble(i) - 0.00001) - 1; - m.set(i, idx); - } - } } return m; } - private MatrixBlockDictionary createIncrementingVector(int nVals, boolean NaN) { + private void BinEncodeParallel(Array a, ColumnEncoderBin b, boolean 
nulls, final AMapToData m, final int rlen) + throws InterruptedException, ExecutionException { + final List> tasks = new ArrayList<>(); + final int blockSize = Math.max(ROW_PARALLELIZATION_THRESHOLD / 2, (rlen + k) / k); + final ExecutorService pool = CommonThreadPool.get(k); + try { + + for(int i = 0; i < rlen; i += blockSize) { + final int start = i; + final int end = Math.min(rlen, i + blockSize); + tasks.add(pool.submit(() -> { + if(nulls) + binEncodeWithNulls(a, b, m, start, end); + else + binEncodeNoNull(a, b, m, start, end); + })); + } + for(Future t : tasks) + t.get(); + } + finally { + pool.shutdown(); + } + } + + private void binEncodeWithNulls(Array a, ColumnEncoderBin b, AMapToData m, int l, int u) { + for(int i = l; i < u; i++) { + final double v = a.getAsNaNDouble(i); + if(Double.isNaN(v)) + m.set(i, b._numBin); + else { + int idx = (int) b.getCodeIndex(v) - 1; + if(idx < 0) + idx = 0; + m.set(i, idx); + } + + } + } + + private final void binEncodeNoNull(Array a, ColumnEncoderBin b, AMapToData m, int l, int u) { + if(b.getBinMethod() == BinMethod.EQUI_WIDTH) + binEncodeNoNullEqWidth(a, b, m, l, u); + else + binEncodeNoNullGeneric(a, b, m, l, u); + } + + private final void binEncodeNoNullEqWidth(Array a, ColumnEncoderBin b, AMapToData m, int l, int u) { + final double min = b.getBinMins()[0]; + final double max = b.getBinMaxs()[b.getNumBin() - 1]; + for(int i = l; i < u; i++) { + m.set(i, b.getEqWidthUnsafe(a.getAsDouble(i), min, max) - 1); + } + } + private final void binEncodeNoNullGeneric(Array a, ColumnEncoderBin b, AMapToData m, int l, int u) { + final double min = b.getBinMins()[0]; + final double max = b.getBinMaxs()[b.getNumBin() - 1]; + for(int i = l; i < u; i++) { + m.set(i, (int) b.getCodeIndex(a.getAsDouble(i), min, max) - 1); + } + } + + private MatrixBlockDictionary createIncrementingVector(int nVals, boolean NaN) { MatrixBlock bins = new MatrixBlock(nVals + (NaN ?
1 : 0), 1, false); for(int i = 0; i < nVals; i++) bins.set(i, 0, i + 1); if(NaN) bins.set(nVals, 0, Double.NaN); - return MatrixBlockDictionary.create(bins); - } - private AColGroup binToDummy(ColumnEncoderComposite c) { + private AColGroup binToDummy(ColumnEncoderComposite c) throws InterruptedException, ExecutionException { final int colId = c._colID; final Array a = in.getColumn(colId - 1); - final boolean containsNull = a.containsNull(); final List r = c.getEncoders(); final ColumnEncoderBin b = (ColumnEncoderBin) r.get(0); - b.build(in); + b.build(in); // build first since we figure out if it contains null here. + final boolean containsNull = b.containsNull; IColIndex colIndexes = ColIndexFactory.create(0, b._numBin); ADictionary d = new IdentityDictionary(colIndexes.size(), containsNull); - AMapToData m = binEncode(a, b, containsNull); + final AMapToData m; + m = binEncode(a, b, containsNull); AColGroup ret = ColGroupDDC.create(colIndexes, d, m, null); return ret; } @SuppressWarnings("unchecked") - private AColGroup recode(ColumnEncoderComposite c) { + private AColGroup recode(ColumnEncoderComposite c) throws Exception { int colId = c._colID; - Array a = in.getColumn(colId - 1); - Map map = a.getRecodeMap(); + Array a = (Array) in.getColumn(colId - 1); + estimateRCDMapSize(c); + Map map = a.getRecodeMap(c._estNumDistincts, CommonThreadPool.get(k)); boolean containsNull = a.containsNull(); int domain = map.size(); @@ -280,107 +355,154 @@ private AColGroup recode(ColumnEncoderComposite c) { AMapToData m = createMappingAMapToData(a, map, containsNull); List r = c.getEncoders(); - r.set(0, new ColumnEncoderRecode(colId, (HashMap) map)); + r.set(0, new ColumnEncoderRecode(colId, (HashMap) map)); return ColGroupDDC.create(colIndexes, d, m, null); } @SuppressWarnings("unchecked") - private AColGroup passThrough(ColumnEncoderComposite c) { - // TODO optimize to not construct full map but only some of it until aborting compression. 
- IColIndex colIndexes = ColIndexFactory.create(1); - int colId = c._colID; - Array a = in.getColumn(colId - 1); - if(a instanceof ACompressedArray){ - switch(a.getFrameArrayType()) { - case DDC: - DDCArray aDDC = (DDCArray) a; - Array dict = aDDC.getDict(); - double[] vals = new double[dict.size()]; - for(int i = 0; i < dict.size(); i++) { - vals[i] = dict.getAsDouble(i); - } - ADictionary d = Dictionary.create(vals); - - return ColGroupDDC.create(colIndexes, d, aDDC.getMap(), null); - default: - throw new NotImplementedException(); - } - } - boolean containsNull = a.containsNull(); - HashMap map = (HashMap) a.getRecodeMap(); - final int blockSz = ConfigurationManager.getDMLConfig().getIntValue(DMLConfig.DEFAULT_BLOCK_SIZE); - if(map.size() >= blockSz) { + private AColGroup passThrough(ColumnEncoderComposite c) throws Exception { + + final IColIndex colIndexes = ColIndexFactory.create(1); + final int colId = c._colID; + final Array a = (Array) in.getColumn(colId - 1); + if(a instanceof ACompressedArray) + return passThroughCompressed(colIndexes, a); + else + return passThroughNormal(c, colIndexes, a); + } + + private AColGroup passThroughNormal(ColumnEncoderComposite c, final IColIndex colIndexes, final Array a) + throws InterruptedException, ExecutionException, Exception { + // Take a small sample + ArrayCompressionStatistics stats = !inputContainsCompressed ? // + a.statistics(Math.min(1000, a.size())) : null; + + if(a.getValueType() != ValueType.BOOLEAN // if not booleans + && (stats == null || !stats.shouldCompress || stats.valueType != a.getValueType())) { + // stats.valueType; double[] vals = (double[]) a.changeType(ValueType.FP64).get(); + MatrixBlock col = new MatrixBlock(a.size(), 1, vals); - col.recomputeNonZeros(); - // lets make it an uncompressed column group. 
+ col.recomputeNonZeros(1); return ColGroupUncompressed.create(colIndexes, col, false); } else { + boolean containsNull = a.containsNull(); + estimateRCDMapSize(c); + Map map = a.getRecodeMap(c._estNumDistincts, CommonThreadPool.get(k)); double[] vals = new double[map.size() + (containsNull ? 1 : 0)]; if(containsNull) vals[map.size()] = Double.NaN; ValueType t = a.getValueType(); - map.forEach((k, v) -> vals[v.intValue()-1] = UtilFunctions.objectToDouble(t, k)); + map.forEach((k, v) -> vals[v.intValue() - 1] = UtilFunctions.objectToDouble(t, k)); ADictionary d = Dictionary.create(vals); AMapToData m = createMappingAMapToData(a, map, containsNull); return ColGroupDDC.create(colIndexes, d, m, null); } + } + + private AColGroup passThroughCompressed(final IColIndex colIndexes, final Array a) { + // only DDC possible currently. + DDCArray aDDC = (DDCArray) a; + Array dict = aDDC.getDict(); + double[] vals = new double[dict.size()]; + if(a.containsNull()) + for(int i = 0; i < dict.size(); i++) + vals[i] = dict.getAsNaNDouble(i); + else + for(int i = 0; i < dict.size(); i++) + vals[i] = dict.getAsDouble(i); + ADictionary d = Dictionary.create(vals); + + return ColGroupDDC.create(colIndexes, d, aDDC.getMap(), null); } - private AMapToData createMappingAMapToData(Array a, Map map, boolean containsNull) { + private AMapToData createMappingAMapToData(Array a, Map map, boolean containsNull) + throws Exception { + final int si = map.size(); + final int nRow = in.getNumRows(); + if(!containsNull && a instanceof DDCArray) + return ((DDCArray) a).getMap(); + + final AMapToData m = MapToFactory.create(nRow, si + (containsNull ? 
1 : 0)); + + if(k > 1 && nRow > ROW_PARALLELIZATION_THRESHOLD) + return CreateMappingParallel(a, map, containsNull, si, nRow, m); + else + return createMappingSingleThread(a, map, containsNull, si, nRow, m); + } + + private AMapToData CreateMappingParallel(Array a, Map map, boolean containsNull, final int si, + final int nRow, final AMapToData m) throws InterruptedException, ExecutionException { + final int blkz = Math.max(ROW_PARALLELIZATION_THRESHOLD / 2, (nRow + k) / k); + + List> tasks = new ArrayList<>(); + // make a thread local pool. + // this pool is independent of the potential generally shared pool + ExecutorService pool = CommonThreadPool.get(k); try { + for(int i = 0; i < nRow; i += blkz) { + final int start = i; + final int end = Math.min(nRow, i + blkz); - final int si = map.size(); - AMapToData m = MapToFactory.create(in.getNumRows(), si + (containsNull ? 1 : 0)); - Array.ArrayIterator it = a.getIterator(); - if(containsNull) { - - while(it.hasNext()) { - Object v = it.next(); - try{ - if(v != null) - m.set(it.getIndex(), map.get(v).intValue() -1); - else - m.set(it.getIndex(), si); - } - catch(Exception e){ - throw new RuntimeException("failed on " + v +" " + a.getValueType(), e); - } - } - } - else { - while(it.hasNext()) { - Object v = it.next(); - m.set(it.getIndex(), map.get(v).intValue() -1); - } + tasks.add(pool.submit(() -> { + if(containsNull) + return createMappingAMapToDataWithNull(a, map, si, m, start, end); + else + return createMappingAMapToDataNoNull(a, map, m, start, end); + + })); } + + for(Future t : tasks) + t.get(); return m; } - catch(Exception e) { - throw new RuntimeException("failed constructing map: " + map, e); + finally { + pool.shutdown(); } + + } + + private AMapToData createMappingSingleThread(Array a, Map map, boolean containsNull, final int si, + final int nRow, final AMapToData m) { + if(containsNull) + return createMappingAMapToDataWithNull(a, map, si, m, 0, nRow); + else + return createMappingAMapToDataNoNull(a, 
map, m, 0, nRow); + } + + private static AMapToData createMappingAMapToDataNoNull(Array a, Map map, AMapToData m, int start, + int end) { + for(int i = start; i < end; i++) + a.setM(map, m, i); + return m; + } + + private static AMapToData createMappingAMapToDataWithNull(Array a, Map map, int si, AMapToData m, + int start, int end) { + for(int i = start; i < end; i++) + a.setM(map, si, m, i); + return m; } private AMapToData createHashMappingAMapToData(Array a, int k, boolean nulls) { AMapToData m = MapToFactory.create(a.size(), k + (nulls ? 1 : 0)); if(nulls) { for(int i = 0; i < a.size(); i++) { - double h = Math.abs(a.hashDouble(i)); - if(Double.isNaN(h)) { + double h = Math.abs(a.hashDouble(i)) % k; + if(Double.isNaN(h)) m.set(i, k); - } - else { - m.set(i, (int) h % k); - } + else + m.set(i, (int) h); } } else { for(int i = 0; i < a.size(); i++) { - double h = Math.abs(a.hashDouble(i)); - m.set(i, (int) h % k); + double h = Math.abs(a.hashDouble(i)) % k; + m.set(i, (int) h); } } return m; } @@ -423,17 +545,38 @@ private AColGroup hashToDummy(ColumnEncoderComposite c) { return ColGroupDDC.create(colIndexes, d, m, null); } - private class EncodeTask implements Callable { - - ColumnEncoderComposite c; - - protected EncodeTask(ColumnEncoderComposite c) { - this.c = c; + @SuppressWarnings("unchecked") + private void estimateRCDMapSize(ColumnEncoderComposite c) { + if(c._estNumDistincts != 0) + return; + Array col = (Array) in.getColumn(c._colID - 1); + if(col instanceof DDCArray) { + DDCArray ddcCol = (DDCArray) col; + c._estNumDistincts = ddcCol.getDict().size(); + return; } - - public AColGroup call() throws Exception { - return encode(c); + final int nRow = in.getNumRows(); + if(nRow <= 1024) { + c._estNumDistincts = 10; + return; + } + // ~2% sample, clamped to [1024, 8192] rows + int sampleSize = Math.max(Math.min(in.getNumRows() / 50, 4096 * 2), 1024); + // Find the frequencies of distinct values in the sample + Map distinctFreq = new HashMap<>(); + for(int sind = 0; sind <
sampleSize; sind++) { + T key = col.getInternal(sind); + if(distinctFreq.containsKey(key)) + distinctFreq.put(key, distinctFreq.get(key) + 1); + else + distinctFreq.put(key, 1); } + + // Estimate total #distincts using Hass and Stokes estimator + int[] freq = distinctFreq.values().stream().mapToInt(v -> v).toArray(); + int estDistCount = SampleEstimatorFactory.distinctCount(freq, nRow, sampleSize, + SampleEstimatorFactory.EstimationType.HassAndStokes); + c._estNumDistincts = estDistCount; } private void logging(MatrixBlock mb) { diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderMVImpute.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderMVImpute.java index 11107b6df6c..60e7ec78051 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderMVImpute.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderMVImpute.java @@ -325,7 +325,7 @@ public void initMetaData(FrameBlock meta) { int colID = _colList[j]; String mvVal = UtilFunctions.unquote(meta.getColumnMetadata(colID - 1).getMvValue()); if(_rcList.contains(colID)) { - Long mvVal2 = meta.getRecodeMap(colID - 1).get(mvVal); + Integer mvVal2 = meta.getRecodeMap(colID - 1).get(mvVal); if(mvVal2 == null) throw new RuntimeException( "Missing recode value for impute value '" + mvVal + "' (colID=" + colID + ")."); diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java index 79c05ca8e72..7cafa0e437a 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java @@ -19,6 +19,8 @@ package org.apache.sysds.runtime.transform.encode; +import static org.apache.sysds.utils.MemoryEstimates.intArrayCost; + import java.io.IOException; import java.io.ObjectInput; import java.io.ObjectOutput; @@ -29,8 +31,6 @@ import 
java.util.HashSet; import java.util.List; import java.util.Map; -import java.util.Objects; -import java.util.Set; import java.util.concurrent.Callable; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; @@ -54,6 +54,7 @@ import org.apache.sysds.runtime.data.DenseBlockFP64DEDUP; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseBlockCSR; +import org.apache.sysds.runtime.data.SparseBlockMCSR; import org.apache.sysds.runtime.data.SparseRowVector; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -64,8 +65,6 @@ import org.apache.sysds.runtime.util.IndexRange; import org.apache.sysds.utils.stats.TransformStatistics; -import static org.apache.sysds.utils.MemoryEstimates.intArrayCost; - public class MultiColumnEncoder implements Encoder { protected static final Log LOG = LogFactory.getLog(MultiColumnEncoder.class.getName()); @@ -365,10 +364,11 @@ public MatrixBlock apply(CacheBlock in, MatrixBlock out, int outputCol, int k // There should be a encoder for every column if(hasLegacyEncoder() && !(in instanceof FrameBlock)) throw new DMLRuntimeException("LegacyEncoders do not support non FrameBlock Inputs"); - int numEncoders = getFromAll(ColumnEncoderComposite.class, ColumnEncoder::getColID).size(); + int numEncoders = getEncoders().size(); + // getFromAll(ColumnEncoderComposite.class, ColumnEncoder::getColID).size(); if(in.getNumColumns() != numEncoders) throw new DMLRuntimeException("Not every column in has a CompositeEncoder. 
Please make sure every column " - + "has a encoder or slice the input accordingly"); + + "has a encoder or slice the input accordingly: num encoders: " + getEncoders() + " vs columns " + in.getNumColumns()); // TODO smart checks // Block allocation for MT access if(in.getNumRows() == 0) @@ -834,59 +834,51 @@ private static void aggregateNnzPerRow(int start, int blk_len, int numRows, List private void outputMatrixPostProcessing(MatrixBlock output, int k){ long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0; - if(output.isInSparseFormat()){ + if(output.isInSparseFormat() && containsZeroOut()){ if (k == 1) outputMatrixPostProcessingSingleThread(output); else outputMatrixPostProcessingParallel(output, k); } - else { - output.recomputeNonZeros(k); - } + output.recomputeNonZeros(k); + if(DMLScript.STATISTICS) TransformStatistics.incOutMatrixPostProcessingTime(System.nanoTime()-t0); } - private void outputMatrixPostProcessingSingleThread(MatrixBlock output){ - Set indexSet = _columnEncoders.stream() - .map(ColumnEncoderComposite::getSparseRowsWZeros).flatMap(l -> { - if(l == null) - return null; - return l.stream(); - }).collect(Collectors.toSet()); - - if(!indexSet.stream().allMatch(Objects::isNull)) { - for(Integer row : indexSet) - output.getSparseBlock().get(row).compact(); + final SparseBlock sb = output.getSparseBlock(); + if(sb instanceof SparseBlockMCSR) { + IntStream.range(0, output.getNumRows()).forEach(row -> { + sb.compact(row); + }); + } + else { + ((SparseBlockCSR) sb).compact(); } - - output.recomputeNonZeros(); } + private boolean containsZeroOut() { + for(ColumnEncoder e : _columnEncoders) + if(e.containsZeroOut()) + return true; + return false; + } private void outputMatrixPostProcessingParallel(MatrixBlock output, int k) { ExecutorService pool = CommonThreadPool.get(k); try { - // Collect the row indices that need compaction - Set indexSet = pool.submit(() -> _columnEncoders.stream().parallel() - 
.map(ColumnEncoderComposite::getSparseRowsWZeros).flatMap(l -> { - if(l == null) - return null; - return l.stream(); - }).collect(Collectors.toSet())).get(); - - // Check if the set is empty - boolean emptySet = pool.submit(() -> indexSet.stream().parallel().allMatch(Objects::isNull)).get(); - - // Concurrently compact the rows - if(emptySet) { + final SparseBlock sb = output.getSparseBlock(); + if(sb instanceof SparseBlockMCSR) { pool.submit(() -> { - indexSet.stream().parallel().forEach(row -> { - output.getSparseBlock().get(row).compact(); + IntStream.range(0, output.getNumRows()).parallel().forEach(row -> { + sb.compact(row); }); }).get(); } + else { + ((SparseBlockCSR) sb).compact(); + } } catch(Exception ex) { throw new DMLRuntimeException(ex); @@ -894,8 +886,6 @@ private void outputMatrixPostProcessingParallel(MatrixBlock output, int k) { finally { pool.shutdown(); } - - output.recomputeNonZeros(); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java b/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java index bad708d691f..7eb0c2bd722 100644 --- a/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java +++ b/src/main/java/org/apache/sysds/runtime/util/CommonThreadPool.java @@ -19,9 +19,9 @@ package org.apache.sysds.runtime.util; +import java.util.ArrayList; import java.util.Collection; import java.util.List; -import java.util.Map.Entry; import java.util.concurrent.Callable; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutionException; @@ -71,7 +71,7 @@ public class CommonThreadPool implements ExecutorService { private final ExecutorService _pool; /** Local variable indicating if there was a thread that was not main, and requested a thread pool */ - private static boolean incorrectPoolUse = false; + public static boolean incorrectPoolUse = false; /** * Constructor of the threadPool. This is intended not to be used except for tests. 
Please use the static @@ -103,11 +103,16 @@ public static ExecutorService get() { * @return The executor with specified parallelism */ public synchronized static ExecutorService get(int k) { + if(k <= 1) { + LOG.warn("Invalid to create thread pool with <= one thread returning single thread executor", + new RuntimeException()); + return new SameThreadExecutorService(); + } + final Thread thisThread = Thread.currentThread(); final String threadName = thisThread.getName(); // Contains main, because we name our test threads TestRunner_main final boolean mainThread = threadName.contains("main"); - if(size == k && mainThread) return shared; // use the default thread pool if main thread and max parallelism. else if(mainThread || threadName.contains("PARFOR")) { @@ -125,13 +130,20 @@ else if(mainThread || threadName.contains("PARFOR")) { } else { // If we are neither a main thread or parfor thread, allocate a new thread pool - if(!incorrectPoolUse){ - LOG.warn("An instruction allocated it's own thread pool indicating that some task is not properly reusing the threads."); + if(!incorrectPoolUse) { + if(threadName.contains("test")) + LOG.error("Thread from test is not correctly using pools, please modify thread name to contain 'main'", + new RuntimeException()); + else + LOG.warn( + "An instruction allocated it's own thread pool indicating that some task is not properly reusing the threads.", + new RuntimeException()); incorrectPoolUse = true; } + return Executors.newFixedThreadPool(k); - } + } } /** @@ -164,7 +176,8 @@ public static void invokeAndShutdown(ExecutorService pool, Collection e : shared2.entrySet()) - for(Runnable a : e.getValue()._pool.shutdownNow()) - a.wait(); - } - catch(Exception e1) { - throw new RuntimeException(e1); - } - finally { - shared2 = null; - } + for(Long e : shared2.keySet()) + shutdownPool(e); + shared2 = null; } } @@ -202,18 +207,15 @@ public synchronized static void shutdownAsyncPools() { * @param thread The thread given that could or could 
not have allocated a thread pool itself. */ public synchronized static void shutdownAsyncPools(Thread thread) { - if(shared2 != null) { - try { - final CommonThreadPool p = shared2.get(thread.getId()); - if(p != null) { - for(Runnable a : p._pool.shutdownNow()) - a.wait(); - shared2.remove(thread.getId()); - } - } - catch(InterruptedException e) { - throw new RuntimeException(e); - } + if(shared2 != null) + shutdownPool(thread.getId()); + } + + private static void shutdownPool(long id) { + final CommonThreadPool p = shared2.get(id); + if(p != null) { + p._pool.shutdownNow(); + shared2.remove(id); } } @@ -224,7 +226,7 @@ public synchronized static void shutdownAsyncPools(Thread thread) { * * @return If there is a thread pool allocated for this thread. */ - public static synchronized boolean generalCached() { + public synchronized static boolean generalCached() { return shared2 != null && shared2.get(Thread.currentThread().getId()) != null; } @@ -326,4 +328,149 @@ else if(name.contains("test")) else return false; } + + public static class SameThreadExecutorService implements ExecutorService { + + private SameThreadExecutorService() { + // private constructor. 
+ } + + @Override + public void execute(Runnable command) { + command.run(); + } + + @Override + public void shutdown() { + // nothing + } + + @Override + public List shutdownNow() { + return new ArrayList<>(); + } + + @Override + public boolean isShutdown() { + return false; + } + + @Override + public boolean isTerminated() { + return false; + + } + + @Override + public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { + return true; + } + + @Override + public Future submit(Callable task) { + return new NonFuture<>(task); + } + + @Override + public Future submit(Runnable task, T result) { + return new NonFuture<>(() -> { + task.run(); + return result; + }); + } + + @Override + public Future submit(Runnable task) { + return new NonFuture<>(() -> { + task.run(); + return null; + }); + } + + @Override + public List> invokeAll(Collection> tasks) throws InterruptedException { + List> ret = new ArrayList<>(); + for(Callable t : tasks) + ret.add(new NonFuture<>(t)); + return ret; + } + + @Override + public List> invokeAll(Collection> tasks, long timeout, TimeUnit unit) + throws InterruptedException { + return invokeAll(tasks); + } + + @Override + public T invokeAny(Collection> tasks) throws InterruptedException, ExecutionException { + Exception e = null; + for(Callable t : tasks) { + try { + T r = t.call(); + return r; + } + catch(Exception ee) { + e = ee; + } + + } + throw new ExecutionException("failed", e); + } + + @Override + public T invokeAny(Collection> tasks, long timeout, TimeUnit unit) + throws InterruptedException, ExecutionException, TimeoutException { + Exception e = null; + for(Callable t : tasks) { + try { + T r = t.call(); + return r; + } + catch(Exception ee) { + e = ee; + } + + } + throw new ExecutionException("failed", e); + } + + private static class NonFuture implements Future { + + V r; + + protected NonFuture(Callable c) { + try { + r = c.call(); + } + catch(Exception e) { + throw new RuntimeException(e); + } + } + 
+ @Override + public boolean cancel(boolean mayInterruptIfRunning) { + return true; + } + + @Override + public boolean isCancelled() { + return false; + } + + @Override + public boolean isDone() { + return true; + } + + @Override + public V get() throws InterruptedException, ExecutionException { + return r; + } + + @Override + public V get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException { + return r; + } + } + } } diff --git a/src/test/java/org/apache/sysds/test/TestUtils.java b/src/test/java/org/apache/sysds/test/TestUtils.java index bbb29deb4a3..26dff3a12bb 100644 --- a/src/test/java/org/apache/sysds/test/TestUtils.java +++ b/src/test/java/org/apache/sysds/test/TestUtils.java @@ -2336,6 +2336,15 @@ public static FrameBlock generateRandomFrameBlock(int rows, ValueType[] schema, return generateRandomFrameBlock(rows, schema, random); } + /** + * Generate a random FrameBlock + * + * @param rows The number of rows in the block + * @param schema The schema (also determines the number of columns) + * @param seed The seed for the random generators + * @param nullChance The percentage of values that are null. (0 is no nulls, while 1 is all null) + * @return A new FrameBlock + */ public static FrameBlock generateRandomFrameBlock(int rows, ValueType[] schema, long seed, double nullChance){ Random random = (seed == -1) ? 
TestUtils.random : new Random(seed); diff --git a/src/test/java/org/apache/sysds/test/component/frame/FrameApplySchema.java b/src/test/java/org/apache/sysds/test/component/frame/FrameApplySchema.java index a22ba8b3094..16b9ff98525 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/FrameApplySchema.java +++ b/src/test/java/org/apache/sysds/test/component/frame/FrameApplySchema.java @@ -53,9 +53,9 @@ public class FrameApplySchema { protected static final Log LOG = LogFactory.getLog(FrameApplySchema.class.getName()); - static { - FrameLibApplySchema.PAR_ROW_THRESHOLD = 10; - } + // static { + // FrameLibApplySchema.PAR_ROW_THRESHOLD = 10; + // } @Test public void testApplySchemaStringToBoolean() { diff --git a/src/test/java/org/apache/sysds/test/component/frame/array/CustomArrayTests.java b/src/test/java/org/apache/sysds/test/component/frame/array/CustomArrayTests.java index 642b3b1b84f..73d04f32435 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/array/CustomArrayTests.java +++ b/src/test/java/org/apache/sysds/test/component/frame/array/CustomArrayTests.java @@ -1204,14 +1204,14 @@ public void changeTypeNullsFromStringToBoolean() { public void mappingCache() { Array a = new StringArray(new String[] {"1", null}); assertEquals(null, a.getCache()); - a.setCache(new SoftReference>(null)); + a.setCache(new SoftReference>(null)); assertTrue(null != a.getCache()); - a.setCache(new SoftReference>(new HashMap<>())); + a.setCache(new SoftReference>(new HashMap<>())); assertTrue(null != a.getCache()); - Map hm = a.getCache().get(); - hm.put("1", 0L); - hm.put(null, 2L); - assertEquals(Long.valueOf(0L), a.getCache().get().get("1")); + Map hm = a.getCache().get(); + hm.put("1", 0); + hm.put(null, 2); + assertEquals(Integer.valueOf(0), a.getCache().get().get("1")); } @Test @@ -1256,7 +1256,7 @@ public void DDCCompressAbort() { } } - @Test(expected = DMLRuntimeException.class) + @Test public void DDCCompressInvalid() { FrameBlock.debug = true; // 
should be fine in general to set while testing Array b = ArrayFactory.create(new boolean[4]); @@ -1727,7 +1727,7 @@ public void testMinMaxDDC2() { @Test public void createRecodeMap() { Array a = ArrayFactory.create(new int[] {1, 1, 1, 1, 3, 3, 1, 2}); - Map m = a.getRecodeMap(); + Map m = a.getRecodeMap(); assertTrue(3 == m.size()); assertTrue(1L == m.get(1)); assertTrue(2L == m.get(3)); @@ -1738,7 +1738,7 @@ public void createRecodeMap() { @Test public void createRecodeMapWithNull() { Array a = ArrayFactory.create(new Integer[] {1, 1, 1, null, 3, 3, 1, 2}); - Map m = a.getRecodeMap(); + Map m = a.getRecodeMap(); assertTrue(3 == m.size()); assertTrue(1L == m.get(1)); assertTrue(2L == m.get(3)); @@ -1749,7 +1749,7 @@ public void createRecodeMapWithNull() { @Test public void createRecodeMapBoolean() { Array a = ArrayFactory.create(new boolean[] {true, true, false, false, true}); - Map m = a.getRecodeMap(); + Map m = a.getRecodeMap(); assertTrue(2 == m.size()); assertTrue(1 == m.get(true)); assertTrue(2 == m.get(false)); @@ -1758,7 +1758,7 @@ public void createRecodeMapBoolean() { @Test public void createRecodeMapBoolean2() { Array a = ArrayFactory.create(new boolean[] {false, true, false, false, true}); - Map m = a.getRecodeMap(); + Map m = a.getRecodeMap(); assertTrue(2 == m.size()); assertTrue(2 == m.get(true)); assertTrue(1 == m.get(false)); @@ -1767,7 +1767,7 @@ public void createRecodeMapBoolean2() { @Test public void createRecodeMapBoolean3() { Array a = ArrayFactory.create(new boolean[] {true, true}); - Map m = a.getRecodeMap(); + Map m = a.getRecodeMap(); assertTrue(1 == m.size()); assertTrue(1 == m.get(true)); assertTrue(null == m.get(false)); @@ -1776,7 +1776,7 @@ public void createRecodeMapBoolean3() { @Test public void createRecodeMapBooleanWithNull() { Array a = ArrayFactory.create(new Boolean[] {true, null, true}); - Map m = a.getRecodeMap(); + Map m = a.getRecodeMap(); assertTrue(1 == m.size()); assertTrue(1 == m.get(true)); assertTrue(null == 
m.get(false)); @@ -1785,8 +1785,8 @@ public void createRecodeMapBooleanWithNull() { @Test public void createRecodeMapCached() { Array a = ArrayFactory.create(new int[] {1, 1, 1, 1, 3, 3, 1, 2}); - Map m = a.getRecodeMap(); - Map m2 = a.getRecodeMap(); + Map m = a.getRecodeMap(); + Map m2 = a.getRecodeMap(); assertEquals(m, m2); } @@ -2166,9 +2166,9 @@ public void testArrayFactorySet() { Array dict = ((DDCArray) a).getDict(); a = ((DDCArray) a).nullDict(); - Array r = ArrayFactory.set(null, a, 50, 99, 150); + Array r = ArrayFactory.set(null, a.slice(50, 100), 50, 99, 150); ArrayFactory.set(r, a, 0, 49, 150); - ArrayFactory.set(r, a, 50, 149, 150); + ArrayFactory.set(r, a.slice(50, 150), 50, 149, 150); DDCArray rd = (DDCArray) r; diff --git a/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayTests.java b/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayTests.java index 47744f71cac..4643c34a2f9 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayTests.java +++ b/src/test/java/org/apache/sysds/test/component/frame/array/FrameArrayTests.java @@ -80,6 +80,7 @@ public class FrameArrayTests { @Parameters public static Collection data() { ArrayList tests = new ArrayList<>(); + FrameBlock.debug = true; try { int[] seeds = new int[] {1, 6, 123, 232}; @@ -467,7 +468,7 @@ public void getStatistics() { ArrayCompressionStatistics s = (a.size() < 1000) ? // a.statistics(a.size()) : a.statistics(1000); assertNotNull(s); // not ever allowed to be null!! - if(a.getValueType() != ValueType.BOOLEAN || a.containsNull()) + if(a.getValueType() != ValueType.BOOLEAN || a.containsNull()) assertTrue(s.toString(), s.compressedSizeEstimate <= s.originalSize); else // not true if we do some other compression scheme. but in general Boolean makes it bigger. 
assertTrue(s.toString(), s.compressedSizeEstimate >= s.originalSize); @@ -532,11 +533,17 @@ public void getFrameArrayType() { @Test public void testSliceStart() { - int size = a.size(); - if(size <= 1) - return; - Array aa = a.slice(0, a.size() - 2); - compare(aa, a, 0); + try { + int size = a.size(); + if(size <= 1) + return; + Array aa = a.slice(0, a.size() - 2); + compare(aa, a, 0); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } } @Test @@ -2146,11 +2153,14 @@ public void NotEquals() { @Test public void createRecodeMap() { if(a.size() < 500) { - Map m = a.getRecodeMap(); + Map m = a.getRecodeMap(); for(int i = 0; i < a.size(); i++) { - Object v = a.get(i); + Object v = a.getInternal(i); if(v != null) { - assertTrue(m.containsKey(v)); + if(!m.containsKey(v)) { + fail("For Array Class:" + a.getClass().getSimpleName() + " Recode map " + m + " did not contain key " + + v); + } } } } @@ -2159,7 +2169,6 @@ public void createRecodeMap() { @Test public void extractDouble() { try { - double[] ret = new double[a.size()]; a.extractDouble(ret, 0, a.size()); for(int i = 0; i < a.size(); i++) { diff --git a/src/test/java/org/apache/sysds/test/component/frame/array/RecodeMapTest.java b/src/test/java/org/apache/sysds/test/component/frame/array/RecodeMapTest.java new file mode 100644 index 00000000000..5cac2f22201 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/frame/array/RecodeMapTest.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.frame.array; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import java.util.List; +import java.util.Map; + +import org.apache.log4j.Level; +import org.apache.log4j.Logger; +import org.apache.log4j.spi.LoggingEvent; +import org.apache.sysds.runtime.frame.data.columns.Array; +import org.apache.sysds.runtime.frame.data.columns.ArrayFactory; +import org.apache.sysds.runtime.util.CommonThreadPool; +import org.apache.sysds.test.LoggingUtils; +import org.apache.sysds.test.LoggingUtils.TestAppender; +import org.junit.Test; + +public class RecodeMapTest { + + @Test + public void createRecodeMapLoggingDebug() throws Exception { + final TestAppender appender = LoggingUtils.overwrite(); + try { + Logger.getLogger(Array.class).setLevel(Level.DEBUG); + Array a = ArrayFactory.create(FrameArrayTests.generateRandomStringNUnique(100, 324, 10)); + + Map rcm = a.getRecodeMap(10); + assertTrue(rcm.size() == 10); + final List log = LoggingUtils.reinsert(appender); + assertTrue(log.size() >= 1); + } + finally { + LoggingUtils.reinsert(appender); + } + + } + + @Test + public void createRecodeMapParallel() throws Exception { + final TestAppender appender = LoggingUtils.overwrite(); + int tmp = Array.ROW_PARALLELIZATION_THRESHOLD; + try { + Array.ROW_PARALLELIZATION_THRESHOLD = 10; + Logger.getLogger(Array.class).setLevel(Level.DEBUG); + Array a = ArrayFactory.create(FrameArrayTests.generateRandomStringNUnique(1000, 324, 10)); + + Map rcm = a.getRecodeMap(10, 
CommonThreadPool.get(10)); + assertTrue(rcm.size() == 10); + final List log = LoggingUtils.reinsert(appender); + assertTrue(log.size() >= 1); + } + finally { + LoggingUtils.reinsert(appender); + Array.ROW_PARALLELIZATION_THRESHOLD = tmp; + } + + } + + @Test + public void createRecodeMapParallel2() throws Exception { + final TestAppender appender = LoggingUtils.overwrite(); + int tmp = Array.ROW_PARALLELIZATION_THRESHOLD; + try { + Array.ROW_PARALLELIZATION_THRESHOLD = 10; + Logger.getLogger(Array.class).setLevel(Level.DEBUG); + Array a = ArrayFactory.create(FrameArrayTests.generateRandomStringNUnique(1000, 324, 500)); + + Map rcm = a.getRecodeMap(10, CommonThreadPool.get(10)); + Map rcm2 = a.getRecodeMap(10, null); + assertTrue(Math.abs(rcm.size() - 500) < 100); + + assertTrue(rcm.size() == rcm2.size()); + + for(String k : rcm.keySet()){ + assertEquals(rcm.get(k), rcm2.get(k)); + } + final List log = LoggingUtils.reinsert(appender); + assertTrue(log.size() >= 1); + } + finally { + LoggingUtils.reinsert(appender); + Array.ROW_PARALLELIZATION_THRESHOLD = tmp; + } + + } +} diff --git a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestLogger.java b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestLogger.java index 1ea9ca344e2..9576a4d8052 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestLogger.java +++ b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestLogger.java @@ -80,19 +80,26 @@ public void test(String spec) { // FrameBlock outNormalMD = encoderNormal.getMetaData(null); final List log = LoggingUtils.reinsert(appender); - assertTrue(log.get(3).getMessage().toString().contains("Compression ratio")); + + boolean containsCompressionRationMessage = false; + for(LoggingEvent l : log) { + containsCompressionRationMessage |= l.getMessage().toString().contains("Compression ratio"); + } + + 
assertTrue(containsCompressionRationMessage); + TestUtils.compareMatrices(outNormal, outCompressed, 0, "Not Equal after apply"); MultiColumnEncoder ec = EncoderFactory.createEncoder(spec, data.getColumnNames(), data.getNumColumns(), encoderCompressed.getMetaData(null)); - + MatrixBlock outMeta1 = ec.apply(data, 1); TestUtils.compareMatrices(outNormal, outMeta1, 0, "Not Equal after apply"); MultiColumnEncoder ec2 = EncoderFactory.createEncoder(spec, data.getColumnNames(), data.getNumColumns(), encoderNormal.getMetaData(null)); - + MatrixBlock outMeta12 = ec2.apply(data, 1); TestUtils.compareMatrices(outNormal, outMeta12, 0, "Not Equal after apply2"); } diff --git a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestMultiCol.java b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestMultiCol.java index af81216412c..8094a59f48b 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestMultiCol.java +++ b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestMultiCol.java @@ -24,14 +24,19 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; +import java.util.logging.Level; +import java.util.logging.Logger; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.lib.FrameLibCompress; import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.transform.encode.CompressedEncode; import org.apache.sysds.runtime.transform.encode.EncoderFactory; import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder; +import org.apache.sysds.runtime.util.CommonThreadPool; import org.apache.sysds.test.TestUtils; import org.junit.Test; import org.junit.runner.RunWith; @@ -46,6 +51,9 @@ public class 
TransformCompressedTestMultiCol { private final int k; public TransformCompressedTestMultiCol(FrameBlock data, int k) { + Thread.currentThread().setName("test_transformThread"); + Logger.getLogger(CommonThreadPool.class.getName()).setLevel(Level.OFF); + CompressedEncode.ROW_PARALLELIZATION_THRESHOLD = 10; this.data = data; this.k = k; } @@ -77,6 +85,12 @@ public static Collection data() { TestUtils.generateRandomFrameBlock(5, kPlusCols, 322), TestUtils.generateRandomFrameBlock(1020, kPlusCols, 322), + FrameLibCompress.compress(TestUtils.generateRandomFrameBlock(1030, new ValueType[] { + ValueType.UINT4, ValueType.BOOLEAN, ValueType.UINT4}, 231, 0.0), 2), + FrameLibCompress.compress(TestUtils.generateRandomFrameBlock(1030, new ValueType[] { + ValueType.UINT4, ValueType.BOOLEAN, ValueType.UINT4}, 231, 0.5), 2), + + }; blocks[2].ensureAllocatedColumns(20); diff --git a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestSingleCol.java b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestSingleCol.java index 3a5d05919e9..d5b8a094154 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestSingleCol.java +++ b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestSingleCol.java @@ -23,14 +23,19 @@ import java.util.ArrayList; import java.util.Collection; +import java.util.logging.Level; +import java.util.logging.Logger; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.lib.FrameLibCompress; import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.transform.encode.CompressedEncode; import org.apache.sysds.runtime.transform.encode.EncoderFactory; import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder; +import 
org.apache.sysds.runtime.util.CommonThreadPool; import org.apache.sysds.test.TestUtils; import org.junit.Test; import org.junit.runner.RunWith; @@ -45,6 +50,9 @@ public class TransformCompressedTestSingleCol { private final int k; public TransformCompressedTestSingleCol(FrameBlock data, int k) { + Thread.currentThread().setName("test_transformThread"); + Logger.getLogger(CommonThreadPool.class.getName()).setLevel(Level.OFF); + CompressedEncode.ROW_PARALLELIZATION_THRESHOLD = 10; this.data = data; this.k = k; } @@ -59,6 +67,12 @@ public static Collection data() { TestUtils.generateRandomFrameBlock(100, new ValueType[] {ValueType.UINT4}, 231, 0.2), TestUtils.generateRandomFrameBlock(100, new ValueType[] {ValueType.UINT4}, 231, 1.0), TestUtils.generateRandomFrameBlock(100, new ValueType[] {ValueType.UINT4}, 231, 1.0), + + FrameLibCompress + .compress(TestUtils.generateRandomFrameBlock(103, new ValueType[] {ValueType.UINT4}, 231, 1.0), 2), + FrameLibCompress + .compress(TestUtils.generateRandomFrameBlock(235, new ValueType[] {ValueType.UINT4}, 23132, 0.0), 2), + // Above block size of number of unique elements TestUtils.generateRandomFrameBlock(1200, new ValueType[] {ValueType.FP32}, 231, 0.1),}; @@ -146,7 +160,7 @@ public void test(String spec) { MultiColumnEncoder encoderNormal = EncoderFactory.createEncoder(spec, data.getColumnNames(), data.getNumColumns(), meta); MatrixBlock outNormal = encoderNormal.encode(data, k); - + MultiColumnEncoder encoderCompressed = EncoderFactory.createEncoder(spec, data.getColumnNames(), data.getNumColumns(), meta); MatrixBlock outCompressed = encoderCompressed.encode(data, k, true); @@ -158,14 +172,14 @@ public void test(String spec) { MultiColumnEncoder ec = EncoderFactory.createEncoder(spec, data.getColumnNames(), data.getNumColumns(), encoderCompressed.getMetaData(null)); - + MatrixBlock outMeta1 = ec.apply(data, k); TestUtils.compareMatrices(outNormal, outMeta1, 0, "Not Equal after apply"); MultiColumnEncoder ec2 = 
EncoderFactory.createEncoder(spec, data.getColumnNames(), data.getNumColumns(), encoderNormal.getMetaData(null)); - + MatrixBlock outMeta12 = ec2.apply(data, k); TestUtils.compareMatrices(outNormal, outMeta12, 0, "Not Equal after apply2"); diff --git a/src/test/java/org/apache/sysds/test/component/misc/ThreadPool.java b/src/test/java/org/apache/sysds/test/component/misc/ThreadPoolTests.java similarity index 69% rename from src/test/java/org/apache/sysds/test/component/misc/ThreadPool.java rename to src/test/java/org/apache/sysds/test/component/misc/ThreadPoolTests.java index 5004d413abf..4d97ef7acc4 100644 --- a/src/test/java/org/apache/sysds/test/component/misc/ThreadPool.java +++ b/src/test/java/org/apache/sysds/test/component/misc/ThreadPoolTests.java @@ -24,6 +24,7 @@ import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -34,19 +35,36 @@ import java.util.concurrent.Callable; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.concurrent.Future; +import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.log4j.Level; +import org.apache.log4j.Logger; +import org.apache.log4j.spi.LoggingEvent; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.util.CommonThreadPool; +import org.apache.sysds.test.LoggingUtils; +import org.apache.sysds.test.LoggingUtils.TestAppender; import org.apache.sysds.utils.stats.InfrastructureAnalyzer; import org.junit.Test; -public class ThreadPool { - protected static final Log LOG = 
LogFactory.getLog(ThreadPool.class.getName()); +public class ThreadPoolTests { + protected static final Log LOG = LogFactory.getLog(ThreadPoolTests.class.getName()); + + Thread.UncaughtExceptionHandler h = new Thread.UncaughtExceptionHandler() { + @Override + public void uncaughtException(Thread th, Throwable ex) { + ex.printStackTrace(); + ; + fail(th.getName() + " " + ex.getMessage()); + throw new RuntimeException(ex); + } + }; @Test public void testGetTheSame() { @@ -255,16 +273,27 @@ public void justWorksShutdownNow() throws InterruptedException, ExecutionExcepti @Test public void justWorksShutdownNowNotMain() throws InterruptedException, ExecutionException { - for(int j = 0; j < 2; j++) { - - for(int i = 4; i < 16; i++) { - ExecutorService p = CommonThreadPool.get(i); - final Integer l = i; - assertEquals(l, p.submit(() -> l).get()); - p.shutdownNow(); - + Thread t = new Thread(() -> { + for(int j = 0; j < 2; j++) { + + for(int i = 4; i < 16; i++) { + ExecutorService p = CommonThreadPool.get(i); + final Integer l = i; + try { + assertEquals(l, p.submit(() -> l).get()); + } + catch(Exception e) { + + } + finally { + p.shutdown(); + } + + } } - } + }, "somethingOtherThanMM"); + t.start(); + t.join(); } @Test @@ -415,4 +444,210 @@ public void ParallelismThread() throws Exception { t.join(); CommonThreadPool.shutdownAsyncPools(t); } + + @Test + public void ParallelismThread_test() throws Exception { + Thread t = new Thread(() -> { + assertTrue(CommonThreadPool.useParallelismOnThread()); + }, "fdsfasdftestfdsfa"); + t.start(); + t.join(); + CommonThreadPool.shutdownAsyncPools(t); + } + + @Test + public void get1ThreadPool() { + ExecutorService e = CommonThreadPool.get(1); + assertTrue(e instanceof CommonThreadPool.SameThreadExecutorService); + } + + @Test + public void get1ThreadPoolWorks() throws Exception { + final TestAppender appender = LoggingUtils.overwrite(); + ExecutorService e = CommonThreadPool.get(1); + Future f = e.submit(() -> { + return null; + }); + 
; + assertTrue(f.cancel(true)); + assertFalse(f.isCancelled()); + assertTrue(f.isDone()); + e.shutdown();// does nothing + assertNull(f.get()); + assertNull(f.get(132, TimeUnit.DAYS)); + + e.execute(() -> { + }); // nothing ... + + assertTrue(e.shutdownNow().isEmpty()); + assertFalse(e.isShutdown()); + assertFalse(e.isTerminated()); + assertTrue(e.awaitTermination(0, null)); + + Runnable t = new Runnable() { + @Override + public void run() { + return; + } + + }; + Future r = e.submit(t, new Object()); + assertTrue(r.isDone()); + Future r2 = e.submit(t); + assertTrue(r2.isDone()); + LoggingUtils.reinsert(appender); + } + + @Test + public void getThreadPoolContainingTests() throws Exception { + CommonThreadPool.incorrectPoolUse = false; + final TestAppender appender = LoggingUtils.overwrite(); + ExecutorService pool = Executors.newFixedThreadPool(2); + try { + + pool.submit(() -> { + Thread.currentThread().setName("BAAAAtest"); + ExecutorService p = CommonThreadPool.get(2); + try { + assertTrue(p instanceof ThreadPoolExecutor); + return null; + } + catch(Exception e) { + throw e; + } + finally { + p.shutdown(); + } + }).get(); + + } + finally { + + pool.shutdown(); + + for(LoggingEvent l : LoggingUtils.reinsert(appender)) { + if(l.getLevel() == Level.ERROR) + return; + } + fail("not correctly logged"); + } + + } + + @Test + public void getThreadPoolContainingNoTests() throws Exception { + CommonThreadPool.incorrectPoolUse = false; + final TestAppender appender = LoggingUtils.overwrite(); + Logger.getLogger(CommonThreadPool.class).setLevel(Level.TRACE); + ExecutorService pool = Executors.newFixedThreadPool(2); + try { + + pool.submit(() -> { + Thread.currentThread().setName("BAAAANoTTTessst"); + ExecutorService p = CommonThreadPool.get(2); + try { + assertTrue(p instanceof ThreadPoolExecutor); + return null; + } + catch(Exception e) { + throw e; + } + finally { + p.shutdown(); + } + }).get(); + + } + finally { + + pool.shutdown(); + + 
Logger.getLogger(CommonThreadPool.class).setLevel(Level.ERROR); + for(LoggingEvent l : LoggingUtils.reinsert(appender)) { + if(l.getLevel() == Level.WARN) + return; + } + fail("not correctly logged"); + } + + } + + @Test + public void getThreadLocalSharedPoolsTests() throws Exception { + CommonThreadPool.incorrectPoolUse = false; + Thread[] ts = new Thread[10]; + for(int i = 0; i < 10; i++) { + + ts[i] = new Thread(() -> { + ExecutorService pool = CommonThreadPool.get(2); + try { + assertTrue(pool instanceof CommonThreadPool); + pool.submit(() -> { + try { + Thread.sleep(3000); + } + catch(Exception e) { + throw new RuntimeException(e); + } + }); + } + catch(Exception e) { + throw new RuntimeException(e); + } + finally { + pool.shutdown(); + } + }, "PARFOR_" + i); + + } + + for(Thread t : ts) { + t.setUncaughtExceptionHandler(h); + t.start(); + } + + Thread.sleep(20); + CommonThreadPool.shutdownAsyncPools(); + + for(Thread t : ts) { + t.join(); + } + } + + @Test(expected = RuntimeException.class) + public void get1ThreadPoolException() throws Exception { + ExecutorService pool = CommonThreadPool.get(1); + + pool.submit(() -> { + throw new RuntimeException(); + }).get(); + + } + + @Test + public void generalCached() { + CommonThreadPool.shutdownAsyncPools(); + assertFalse(CommonThreadPool.generalCached()); + + ExecutorService pool = Executors.newFixedThreadPool(2); + + try { + + pool.submit(() -> { + assertFalse(CommonThreadPool.generalCached()); + Thread.currentThread().setName("someThingWith_main"); + ExecutorService e = CommonThreadPool.get(3); + assertTrue(CommonThreadPool.generalCached()); + CommonThreadPool.shutdownAsyncPools(Thread.currentThread()); + assertFalse(CommonThreadPool.generalCached()); + e.shutdown(); + }); + + } + finally { + pool.shutdown(); + } + + } + } diff --git a/src/test/java/org/apache/sysds/test/component/resource/RecompilationTest.java b/src/test/java/org/apache/sysds/test/component/resource/RecompilationTest.java index 
178ae2f46ad..ef3af0eb407 100644 --- a/src/test/java/org/apache/sysds/test/component/resource/RecompilationTest.java +++ b/src/test/java/org/apache/sysds/test/component/resource/RecompilationTest.java @@ -236,7 +236,7 @@ private void runTestMM(String fileX, String fileY, long driverMemory, int number // original compilation used for comparison Program expectedProgram = ResourceCompiler.compile(HOME+"mm_test.dml", nvargs); Program recompiledProgram = runTest(precompiledProgram, expectedProgram, driverMemory, numberExecutors, executorMemory); - System.out.println(Explain.explain(recompiledProgram)); + Optional mmInstruction = ((BasicProgramBlock) recompiledProgram.getProgramBlocks().get(0)).getInstructions().stream() .filter(inst -> (Objects.equals(expectedSparkExecType, inst instanceof SPInstruction) && Objects.equals(inst.getOpcode(), expectedOpcode))) .findFirst(); diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinAdasynRealDataTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinAdasynRealDataTest.java index b816a5dcb8a..e276e45c0b4 100644 --- a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinAdasynRealDataTest.java +++ b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinAdasynRealDataTest.java @@ -70,7 +70,7 @@ public void testTitanicAdasynK4() { @Test public void testTitanicAdasynK5() { - runAdasynTest(TITANIC_DATA, TITANIC_TFSPEC, true, 0.797, 5, ExecType.CP); + runAdasynTest(TITANIC_DATA, TITANIC_TFSPEC, true, 0.786, 5, ExecType.CP); } private void runAdasynTest(String data, String tfspec, boolean adasyn, double minAcc, int k, ExecType instType) { diff --git a/src/test/java/org/apache/sysds/test/functions/transform/ColumnEncoderSerializationTest.java b/src/test/java/org/apache/sysds/test/functions/transform/ColumnEncoderSerializationTest.java index 2bd1e646978..cac7937b526 100644 --- 
a/src/test/java/org/apache/sysds/test/functions/transform/ColumnEncoderSerializationTest.java +++ b/src/test/java/org/apache/sysds/test/functions/transform/ColumnEncoderSerializationTest.java @@ -136,10 +136,10 @@ else if (type == TransformType.BOW) MultiColumnEncoder encoderIn = EncoderFactory.createEncoder(spec, cnames, frame.getNumColumns(), null); if(type == TransformType.BOW){ List encs = encoderIn.getColumnEncoders(ColumnEncoderBagOfWords.class); - HashMap dict = new HashMap<>(); - dict.put("val1", 1L); - dict.put("val2", 2L); - dict.put("val3", 300L); + HashMap dict = new HashMap<>(); + dict.put("val1", 1); + dict.put("val2", 2); + dict.put("val3", 300); encs.forEach(e -> e.setTokenDictionary(dict)); } MultiColumnEncoder encoderOut; @@ -165,7 +165,7 @@ else if (type == TransformType.BOW) List encsIn = encoderIn.getColumnEncoders(ColumnEncoderBagOfWords.class); List encsOut = encoderOut.getColumnEncoders(ColumnEncoderBagOfWords.class); for (int i = 0; i < encsIn.size(); i++) { - Map eOutDict = encsOut.get(i).getTokenDictionary(); + Map eOutDict = encsOut.get(i).getTokenDictionary(); encsIn.get(i).getTokenDictionary().forEach((k,v) -> { assert v.equals(eOutDict.get(k)); }); From 86069c62ddb723383846be3dc46ecb963ab974f5 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 13 Jan 2025 11:46:44 +0100 Subject: [PATCH 02/81] Perf Transform Encode --- .../org/apache/sysds/performance/Main.java | 6 +- .../sysds/performance/frame/Transform.java | 87 +++++++++++++++++++ 2 files changed, 92 insertions(+), 1 deletion(-) create mode 100644 src/test/java/org/apache/sysds/performance/frame/Transform.java diff --git a/src/test/java/org/apache/sysds/performance/Main.java b/src/test/java/org/apache/sysds/performance/Main.java index a02a0f437fd..d7f85c9b788 100644 --- a/src/test/java/org/apache/sysds/performance/Main.java +++ b/src/test/java/org/apache/sysds/performance/Main.java @@ -24,6 +24,7 @@ import org.apache.sysds.performance.compression.Serialize; import 
org.apache.sysds.performance.compression.StreamCompress; import org.apache.sysds.performance.compression.TransformPerf; +import org.apache.sysds.performance.frame.Transform; import org.apache.sysds.performance.generators.ConstMatrix; import org.apache.sysds.performance.generators.FrameFile; import org.apache.sysds.performance.generators.FrameTransformFile; @@ -128,9 +129,12 @@ private static void exec(int prog, String[] args) throws Exception { case 1005: ReshapePerf.main(args); break; - case 1006: + case 1006: MatrixBinaryCellPerf.main(args); break; + case 1007: + Transform.main(args); + break; default: break; } diff --git a/src/test/java/org/apache/sysds/performance/frame/Transform.java b/src/test/java/org/apache/sysds/performance/frame/Transform.java new file mode 100644 index 00000000000..39c9d6d2fa5 --- /dev/null +++ b/src/test/java/org/apache/sysds/performance/frame/Transform.java @@ -0,0 +1,87 @@ +package org.apache.sysds.performance.frame; + +import java.util.Arrays; + +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.performance.compression.APerfTest; +import org.apache.sysds.performance.generators.ConstFrame; +import org.apache.sysds.performance.generators.IGenerate; +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.columns.Array; +import org.apache.sysds.runtime.transform.encode.EncoderFactory; +import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder; +import org.apache.sysds.test.TestUtils; + +public class Transform extends APerfTest { + + private final int k; + private final String spec; + + public Transform(int N, IGenerate gen, int k, String spec) { + super(N, gen); + this.k = k; + this.spec = spec; + FrameBlock in = gen.take(); + System.out.println("Transform Encode Perf: rows: " + in.getNumRows() + " schema:" + Arrays.toString(in.getSchema())); + System.out.println(spec); + } + + public void run() throws Exception { + execute(() -> te(), () -> clear(), "Normal"); 
+ execute(() -> tec(), () -> clear(), "Compressed"); + execute(() -> te(), () -> clear(), "Normal"); + execute(() -> tec(), () -> clear(), "Compressed"); + } + + private void te(){ + FrameBlock in = gen.take(); + MultiColumnEncoder enc = EncoderFactory.createEncoder(spec, in.getNumColumns()); + enc.encode(in, k); + ret.add(null); + } + + private void tec(){ + FrameBlock in = gen.take(); + MultiColumnEncoder enc = EncoderFactory.createEncoder(spec, in.getNumColumns()); + enc.encode(in, k, true); + ret.add(null); + } + + private void clear(){ + clearRDCCache(gen.take()); + } + + @Override + protected String makeResString() { + return ""; + } + + + /** + * Forcefully clear recode cache of underlying arrays + */ + public void clearRDCCache(FrameBlock f){ + for(Array a : f.getColumns()) + a.setCache(null); + } + + + public static void main(String[] args) throws Exception { + for(int i = 1; i < 100; i *= 10){ + + FrameBlock in = TestUtils.generateRandomFrameBlock(100000 * i , new ValueType[]{ValueType.UINT4}, 32); + System.out.println(Arrays.toString(in.getColumnNames())); + ConstFrame gen = new ConstFrame(in); + // passthrough + new Transform(300, gen, 16, "{}").run(); + new Transform(300, gen, 16, "{ids:true, recode:[1]}").run(); + new Transform(300, gen, 16, "{ids:true, bin:[{id:1, method:equi-width, numbins:4}]}").run(); + new Transform(300, gen, 16, "{ids:true, bin:[{id:1, method:equi-width, numbins:4}], dummycode:[1]}").run(); + new Transform(300, gen, 16, "{ids:true, hash:[1], K:10}").run(); + new Transform(300, gen, 16, "{ids:true, hash:[1], K:10, dummycode:[1]}").run(); + } + + System.exit(0); // forcefully stop. 
+ } + +} From 11472b3ad2dba887ac1b6964c1711b5c1ad3e364 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 13 Jan 2025 11:49:47 +0100 Subject: [PATCH 03/81] cleanup --- .../sysds/performance/frame/Transform.java | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/test/java/org/apache/sysds/performance/frame/Transform.java b/src/test/java/org/apache/sysds/performance/frame/Transform.java index 39c9d6d2fa5..dbd87ebf6c6 100644 --- a/src/test/java/org/apache/sysds/performance/frame/Transform.java +++ b/src/test/java/org/apache/sysds/performance/frame/Transform.java @@ -1,3 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + package org.apache.sysds.performance.frame; import java.util.Arrays; From 0a70ce4840f631d81e1a3c58374877eaace6b8e9 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 13 Jan 2025 12:36:08 +0100 Subject: [PATCH 04/81] multi column --- .../transform/encode/CompressedEncode.java | 2 +- .../sysds/performance/frame/Transform.java | 77 ++++++++++++++----- 2 files changed, 58 insertions(+), 21 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java index 7b4698d5bfe..67cf2cee09a 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java @@ -98,7 +98,7 @@ public static MatrixBlock encode(MultiColumnEncoder enc, FrameBlock in, int k) t private MatrixBlock apply() throws Exception { try { final List encoders = enc.getColumnEncoders(); - final List groups = isParallel() ? multiThread(encoders) : singleThread(encoders); + final List groups = singleThread(encoders); //isParallel() ? 
multiThread(encoders) : singleThread(encoders); final int cols = shiftGroups(groups); final MatrixBlock mb = new CompressedMatrixBlock(in.getNumRows(), cols, -1, false, groups); mb.recomputeNonZeros(k); diff --git a/src/test/java/org/apache/sysds/performance/frame/Transform.java b/src/test/java/org/apache/sysds/performance/frame/Transform.java index dbd87ebf6c6..c040eb742b6 100644 --- a/src/test/java/org/apache/sysds/performance/frame/Transform.java +++ b/src/test/java/org/apache/sysds/performance/frame/Transform.java @@ -30,6 +30,7 @@ import org.apache.sysds.runtime.transform.encode.EncoderFactory; import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder; import org.apache.sysds.test.TestUtils; +import org.apache.sysds.utils.stats.InfrastructureAnalyzer; public class Transform extends APerfTest { @@ -41,32 +42,33 @@ public Transform(int N, IGenerate gen, int k, String spec) { this.k = k; this.spec = spec; FrameBlock in = gen.take(); - System.out.println("Transform Encode Perf: rows: " + in.getNumRows() + " schema:" + Arrays.toString(in.getSchema())); + System.out + .println("Transform Encode Perf: rows: " + in.getNumRows() + " schema:" + Arrays.toString(in.getSchema())); System.out.println(spec); } public void run() throws Exception { execute(() -> te(), () -> clear(), "Normal"); execute(() -> tec(), () -> clear(), "Compressed"); - execute(() -> te(), () -> clear(), "Normal"); - execute(() -> tec(), () -> clear(), "Compressed"); + // execute(() -> te(), () -> clear(), "Normal"); + // execute(() -> tec(), () -> clear(), "Compressed"); } - private void te(){ + private void te() { FrameBlock in = gen.take(); MultiColumnEncoder enc = EncoderFactory.createEncoder(spec, in.getNumColumns()); enc.encode(in, k); ret.add(null); } - private void tec(){ + private void tec() { FrameBlock in = gen.take(); MultiColumnEncoder enc = EncoderFactory.createEncoder(spec, in.getNumColumns()); enc.encode(in, k, true); ret.add(null); } - private void clear(){ + private void 
clear() { clearRDCCache(gen.take()); } @@ -75,29 +77,64 @@ protected String makeResString() { return ""; } - - /** + /** * Forcefully clear recode cache of underlying arrays */ - public void clearRDCCache(FrameBlock f){ + public void clearRDCCache(FrameBlock f) { for(Array a : f.getColumns()) a.setCache(null); } - public static void main(String[] args) throws Exception { - for(int i = 1; i < 100; i *= 10){ + int k = InfrastructureAnalyzer.getLocalParallelism(); + for(int i = 1; i < 100; i *= 10) { + + FrameBlock in = TestUtils.generateRandomFrameBlock(100000 * i, new ValueType[] {ValueType.UINT4}, 32); - FrameBlock in = TestUtils.generateRandomFrameBlock(100000 * i , new ValueType[]{ValueType.UINT4}, 32); - System.out.println(Arrays.toString(in.getColumnNames())); ConstFrame gen = new ConstFrame(in); - // passthrough - new Transform(300, gen, 16, "{}").run(); - new Transform(300, gen, 16, "{ids:true, recode:[1]}").run(); - new Transform(300, gen, 16, "{ids:true, bin:[{id:1, method:equi-width, numbins:4}]}").run(); - new Transform(300, gen, 16, "{ids:true, bin:[{id:1, method:equi-width, numbins:4}], dummycode:[1]}").run(); - new Transform(300, gen, 16, "{ids:true, hash:[1], K:10}").run(); - new Transform(300, gen, 16, "{ids:true, hash:[1], K:10, dummycode:[1]}").run(); + // // passthrough + new Transform(300, gen, k, "{}").run(); + new Transform(300, gen, k, "{ids:true, recode:[1]}").run(); + new Transform(300, gen, k, "{ids:true, bin:[{id:1, method:equi-width, numbins:4}]}").run(); + new Transform(300, gen, k, "{ids:true, bin:[{id:1, method:equi-width, numbins:4}], dummycode:[1]}").run(); + new Transform(300, gen, k, "{ids:true, hash:[1], K:10}").run(); + new Transform(300, gen, k, "{ids:true, hash:[1], K:10, dummycode:[1]}").run(); + + in = TestUtils.generateRandomFrameBlock( + 100000 * i, new ValueType[] {ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, + ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, 
ValueType.UINT4}, + 32); + + gen = new ConstFrame(in); + new Transform(300, gen, k, "{}").run(); + new Transform(300, gen, k, "{ids:true, recode:[1,2,3,4,5,6,7,8,9,10]}").run(); + new Transform(300, gen, k, "{ids:true, bin:[" // + + "\n{id:1, method:equi-width, numbins:4}," // + + "\n{id:2, method:equi-width, numbins:4}," // + + "\n{id:3, method:equi-width, numbins:4}," // + + "\n{id:4, method:equi-width, numbins:4}," // + + "\n{id:5, method:equi-width, numbins:4}," // + + "\n{id:6, method:equi-width, numbins:4}," // + + "\n{id:7, method:equi-width, numbins:4}," // + + "\n{id:8, method:equi-width, numbins:4}," // + + "\n{id:9, method:equi-width, numbins:4}," // + + "\n{id:10, method:equi-width, numbins:4}," // + + "]}").run(); + new Transform(300, gen, k, "{ids:true, bin:[" // + + "\n{id:1, method:equi-width, numbins:4}," // + + "\n{id:2, method:equi-width, numbins:4}," // + + "\n{id:3, method:equi-width, numbins:4}," // + + "\n{id:4, method:equi-width, numbins:4}," // + + "\n{id:5, method:equi-width, numbins:4}," // + + "\n{id:6, method:equi-width, numbins:4}," // + + "\n{id:7, method:equi-width, numbins:4}," // + + "\n{id:8, method:equi-width, numbins:4}," // + + "\n{id:9, method:equi-width, numbins:4}," // + + "\n{id:10, method:equi-width, numbins:4}," // + + "], dummycode:[1,2,3,4,5,6,7,8,9,10]}").run(); + new Transform(300, gen, k, "{ids:true, hash:[1,2,3,4,5,6,7,8,9,10], K:10}").run(); + new Transform(300, gen, k, "{ids:true, hash:[1,2,3,4,5,6,7,8,9,10], K:10, dummycode:[1,2,3,4,5,6,7,8,9,10]}").run(); + } System.exit(0); // forcefully stop. 
From 33fbc94070e4c3fd3376e50cb3cbddae60ff90de Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 13 Jan 2025 12:51:14 +0100 Subject: [PATCH 05/81] nowWithCompressed Input --- .../sysds/performance/frame/Transform.java | 89 +++++++++++-------- 1 file changed, 51 insertions(+), 38 deletions(-) diff --git a/src/test/java/org/apache/sysds/performance/frame/Transform.java b/src/test/java/org/apache/sysds/performance/frame/Transform.java index c040eb742b6..be84d79649d 100644 --- a/src/test/java/org/apache/sysds/performance/frame/Transform.java +++ b/src/test/java/org/apache/sysds/performance/frame/Transform.java @@ -27,6 +27,7 @@ import org.apache.sysds.performance.generators.IGenerate; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.frame.data.columns.Array; +import org.apache.sysds.runtime.frame.data.lib.FrameLibCompress; import org.apache.sysds.runtime.transform.encode.EncoderFactory; import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder; import org.apache.sysds.test.TestUtils; @@ -87,57 +88,69 @@ public void clearRDCCache(FrameBlock f) { public static void main(String[] args) throws Exception { int k = InfrastructureAnalyzer.getLocalParallelism(); - for(int i = 1; i < 100; i *= 10) { + for(int i = 1; i < 1000; i *= 10) { FrameBlock in = TestUtils.generateRandomFrameBlock(100000 * i, new ValueType[] {ValueType.UINT4}, 32); - ConstFrame gen = new ConstFrame(in); - // // passthrough - new Transform(300, gen, k, "{}").run(); - new Transform(300, gen, k, "{ids:true, recode:[1]}").run(); - new Transform(300, gen, k, "{ids:true, bin:[{id:1, method:equi-width, numbins:4}]}").run(); - new Transform(300, gen, k, "{ids:true, bin:[{id:1, method:equi-width, numbins:4}], dummycode:[1]}").run(); - new Transform(300, gen, k, "{ids:true, hash:[1], K:10}").run(); - new Transform(300, gen, k, "{ids:true, hash:[1], K:10, dummycode:[1]}").run(); + run(k, in); + FrameLibCompress.compress(in, k); + run(k, in); in = 
TestUtils.generateRandomFrameBlock( 100000 * i, new ValueType[] {ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, ValueType.UINT4}, 32); - gen = new ConstFrame(in); - new Transform(300, gen, k, "{}").run(); - new Transform(300, gen, k, "{ids:true, recode:[1,2,3,4,5,6,7,8,9,10]}").run(); - new Transform(300, gen, k, "{ids:true, bin:[" // - + "\n{id:1, method:equi-width, numbins:4}," // - + "\n{id:2, method:equi-width, numbins:4}," // - + "\n{id:3, method:equi-width, numbins:4}," // - + "\n{id:4, method:equi-width, numbins:4}," // - + "\n{id:5, method:equi-width, numbins:4}," // - + "\n{id:6, method:equi-width, numbins:4}," // - + "\n{id:7, method:equi-width, numbins:4}," // - + "\n{id:8, method:equi-width, numbins:4}," // - + "\n{id:9, method:equi-width, numbins:4}," // - + "\n{id:10, method:equi-width, numbins:4}," // - + "]}").run(); - new Transform(300, gen, k, "{ids:true, bin:[" // - + "\n{id:1, method:equi-width, numbins:4}," // - + "\n{id:2, method:equi-width, numbins:4}," // - + "\n{id:3, method:equi-width, numbins:4}," // - + "\n{id:4, method:equi-width, numbins:4}," // - + "\n{id:5, method:equi-width, numbins:4}," // - + "\n{id:6, method:equi-width, numbins:4}," // - + "\n{id:7, method:equi-width, numbins:4}," // - + "\n{id:8, method:equi-width, numbins:4}," // - + "\n{id:9, method:equi-width, numbins:4}," // - + "\n{id:10, method:equi-width, numbins:4}," // - + "], dummycode:[1,2,3,4,5,6,7,8,9,10]}").run(); - new Transform(300, gen, k, "{ids:true, hash:[1,2,3,4,5,6,7,8,9,10], K:10}").run(); - new Transform(300, gen, k, "{ids:true, hash:[1,2,3,4,5,6,7,8,9,10], K:10, dummycode:[1,2,3,4,5,6,7,8,9,10]}").run(); + run10(k, in); + FrameLibCompress.compress(in, k); + run10(k, in); } System.exit(0); // forcefully stop. 
} + private static void run10(int k, FrameBlock in) throws Exception { + ConstFrame gen = new ConstFrame(in); + new Transform(300, gen, k, "{}").run(); + new Transform(300, gen, k, "{ids:true, recode:[1,2,3,4,5,6,7,8,9,10]}").run(); + new Transform(300, gen, k, "{ids:true, bin:[" // + + "\n{id:1, method:equi-width, numbins:4}," // + + "\n{id:2, method:equi-width, numbins:4}," // + + "\n{id:3, method:equi-width, numbins:4}," // + + "\n{id:4, method:equi-width, numbins:4}," // + + "\n{id:5, method:equi-width, numbins:4}," // + + "\n{id:6, method:equi-width, numbins:4}," // + + "\n{id:7, method:equi-width, numbins:4}," // + + "\n{id:8, method:equi-width, numbins:4}," // + + "\n{id:9, method:equi-width, numbins:4}," // + + "\n{id:10, method:equi-width, numbins:4}," // + + "]}").run(); + new Transform(300, gen, k, "{ids:true, bin:[" // + + "\n{id:1, method:equi-width, numbins:4}," // + + "\n{id:2, method:equi-width, numbins:4}," // + + "\n{id:3, method:equi-width, numbins:4}," // + + "\n{id:4, method:equi-width, numbins:4}," // + + "\n{id:5, method:equi-width, numbins:4}," // + + "\n{id:6, method:equi-width, numbins:4}," // + + "\n{id:7, method:equi-width, numbins:4}," // + + "\n{id:8, method:equi-width, numbins:4}," // + + "\n{id:9, method:equi-width, numbins:4}," // + + "\n{id:10, method:equi-width, numbins:4}," // + + "], dummycode:[1,2,3,4,5,6,7,8,9,10]}").run(); + new Transform(300, gen, k, "{ids:true, hash:[1,2,3,4,5,6,7,8,9,10], K:10}").run(); + new Transform(300, gen, k, "{ids:true, hash:[1,2,3,4,5,6,7,8,9,10], K:10, dummycode:[1,2,3,4,5,6,7,8,9,10]}").run(); + } + + private static void run(int k, FrameBlock in) throws Exception { + ConstFrame gen = new ConstFrame(in); + // // passthrough + new Transform(300, gen, k, "{}").run(); + new Transform(300, gen, k, "{ids:true, recode:[1]}").run(); + new Transform(300, gen, k, "{ids:true, bin:[{id:1, method:equi-width, numbins:4}]}").run(); + new Transform(300, gen, k, "{ids:true, bin:[{id:1, method:equi-width, 
numbins:4}], dummycode:[1]}").run(); + new Transform(300, gen, k, "{ids:true, hash:[1], K:10}").run(); + new Transform(300, gen, k, "{ids:true, hash:[1], K:10, dummycode:[1]}").run(); + } + } From 793eca9b7b757a54838af68cc3b1745fa6a34e3e Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 13 Jan 2025 12:53:50 +0100 Subject: [PATCH 06/81] more --- .../sysds/performance/frame/Transform.java | 31 +++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/src/test/java/org/apache/sysds/performance/frame/Transform.java b/src/test/java/org/apache/sysds/performance/frame/Transform.java index be84d79649d..1105214f699 100644 --- a/src/test/java/org/apache/sysds/performance/frame/Transform.java +++ b/src/test/java/org/apache/sysds/performance/frame/Transform.java @@ -88,11 +88,25 @@ public void clearRDCCache(FrameBlock f) { public static void main(String[] args) throws Exception { int k = InfrastructureAnalyzer.getLocalParallelism(); + FrameBlock in; + for(int i = 1; i < 1000; i *= 10) { - FrameBlock in = TestUtils.generateRandomFrameBlock(100000 * i, new ValueType[] {ValueType.UINT4}, 32); + in = TestUtils.generateRandomFrameBlock(100000 * i, new ValueType[] {ValueType.UINT4}, 32); + + System.out.println("Without null"); + run(k, in); + + System.out.println("Compressed without null"); + FrameLibCompress.compress(in, k); + run(k, in); + + in = TestUtils.generateRandomFrameBlock(100000 * i, new ValueType[] {ValueType.UINT4}, 32, 0.5); + + System.out.println("With null"); run(k, in); + System.out.println("Compressed with null"); FrameLibCompress.compress(in, k); run(k, in); @@ -101,10 +115,22 @@ public static void main(String[] args) throws Exception { ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, ValueType.UINT4}, 32); + System.out.println("10 col without null"); run10(k, in); + System.out.println("10 col compressed without null"); FrameLibCompress.compress(in, k); run10(k, in); + in = 
TestUtils.generateRandomFrameBlock( + 100000 * i, new ValueType[] {ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, + ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, ValueType.UINT4}, + 32, 0.5); + + System.out.println("10 col with null"); + run10(k, in); + System.out.println("10 col Compressed with null"); + FrameLibCompress.compress(in, k); + run10(k, in); } System.exit(0); // forcefully stop. @@ -139,7 +165,8 @@ private static void run10(int k, FrameBlock in) throws Exception { + "\n{id:10, method:equi-width, numbins:4}," // + "], dummycode:[1,2,3,4,5,6,7,8,9,10]}").run(); new Transform(300, gen, k, "{ids:true, hash:[1,2,3,4,5,6,7,8,9,10], K:10}").run(); - new Transform(300, gen, k, "{ids:true, hash:[1,2,3,4,5,6,7,8,9,10], K:10, dummycode:[1,2,3,4,5,6,7,8,9,10]}").run(); + new Transform(300, gen, k, "{ids:true, hash:[1,2,3,4,5,6,7,8,9,10], K:10, dummycode:[1,2,3,4,5,6,7,8,9,10]}") + .run(); } private static void run(int k, FrameBlock in) throws Exception { From 9a88e66853b52b8422ad9ed4210af1dc4cd84dae Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 13 Jan 2025 14:01:23 +0100 Subject: [PATCH 07/81] better count ubyte --- .../compress/colgroup/mapping/MapToUByte.java | 33 ++++-- .../compress/CompressedFrameBlockFactory.java | 2 +- .../transform/encode/CompressedEncode.java | 106 +++++++++-------- .../sysds/performance/frame/Transform.java | 112 +++++++++--------- 4 files changed, 135 insertions(+), 118 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToUByte.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToUByte.java index 97cbfdcde27..4ed649c17f6 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToUByte.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToUByte.java @@ -115,24 +115,38 @@ public int getUpperBoundValue() { @Override public int[] getCounts(int[] ret) 
{ - for(int i = 0; i < _data.length; i++) + final int h = (_data.length) % 8; + for(int i = 0; i < h; i++) ret[_data[i]]++; + getCountsBy8P(ret, h, _data.length); return ret; } + private void getCountsBy8P(int[] ret, int s, int e) { + for(int i = s; i < e; i += 8) { + ret[_data[i]]++; + ret[_data[i + 1]]++; + ret[_data[i + 2]]++; + ret[_data[i + 3]]++; + ret[_data[i + 4]]++; + ret[_data[i + 5]]++; + ret[_data[i + 6]]++; + ret[_data[i + 7]]++; + } + } + @Override protected void decompressToRangeNoOffBy8(double[] c, int r, double[] values) { c[r] += values[_data[r]]; - c[r+1] += values[_data[r+1]]; - c[r+2] += values[_data[r+2]]; - c[r+3] += values[_data[r+3]]; - c[r+4] += values[_data[r+4]]; - c[r+5] += values[_data[r+5]]; - c[r+6] += values[_data[r+6]]; - c[r+7] += values[_data[r+7]]; + c[r + 1] += values[_data[r + 1]]; + c[r + 2] += values[_data[r + 2]]; + c[r + 3] += values[_data[r + 3]]; + c[r + 4] += values[_data[r + 4]]; + c[r + 5] += values[_data[r + 5]]; + c[r + 6] += values[_data[r + 6]]; + c[r + 7] += values[_data[r + 7]]; } - @Override public void decompressToRange(double[] c, int rl, int ru, int offR, double[] values) { // OVERWRITTEN FOR JIT COMPILE! @@ -158,7 +172,6 @@ public void decompressToRangeNoOff(double[] c, int rl, int ru, double[] values) decompressToRangeNoOffBy8(c, rc, values); } - @Override public AMapToData resize(int unique) { final int size = _data.length; diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/compress/CompressedFrameBlockFactory.java b/src/main/java/org/apache/sysds/runtime/frame/data/compress/CompressedFrameBlockFactory.java index de7031c7c01..6cddc728fa0 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/compress/CompressedFrameBlockFactory.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/compress/CompressedFrameBlockFactory.java @@ -208,7 +208,7 @@ private Array compressColFinally(int i, final Array a, final ArrayCompress Timing time = LOG.isDebugEnabled() ? 
new Timing(true) : null; if(s.bestType != null && s.shouldCompress) { if(s.bestType == FrameArrayType.DDC) - compressedColumns[i] = DDCArray.compressToDDC(a); + compressedColumns[i] = DDCArray.compressToDDC(a, s.nUnique); else throw new RuntimeException("Unsupported frame compression encoding : " + s.bestType); } diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java index 67cf2cee09a..c1d4a0f7749 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java @@ -26,6 +26,7 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; +import java.util.concurrent.atomic.AtomicLong; import org.apache.commons.lang3.NotImplementedException; import org.apache.commons.logging.Log; @@ -76,6 +77,8 @@ public class CompressedEncode { private final boolean inputContainsCompressed; + private final AtomicLong nnz = new AtomicLong(); + private CompressedEncode(MultiColumnEncoder enc, FrameBlock in, int k) { this.enc = enc; this.in = in; @@ -98,10 +101,10 @@ public static MatrixBlock encode(MultiColumnEncoder enc, FrameBlock in, int k) t private MatrixBlock apply() throws Exception { try { final List encoders = enc.getColumnEncoders(); - final List groups = singleThread(encoders); //isParallel() ? multiThread(encoders) : singleThread(encoders); + final List groups = isParallel() ? 
multiThread(encoders) : singleThread(encoders); final int cols = shiftGroups(groups); final MatrixBlock mb = new CompressedMatrixBlock(in.getNumRows(), cols, -1, false, groups); - mb.recomputeNonZeros(k); + mb.setNonZeros(nnz.get()); logging(mb); return mb; } @@ -181,7 +184,7 @@ private AColGroup recodeToDummy(ColumnEncoderComposite c) throws Exception { Array a = (Array) in.getColumn(colId - 1); boolean containsNull = a.containsNull(); estimateRCDMapSize(c); - Map map = a.getRecodeMap(c._estNumDistincts, CommonThreadPool.get(k)); + Map map = a.getRecodeMap(c._estNumDistincts, pool); List r = c.getEncoders(); r.set(0, new ColumnEncoderRecode(colId, (HashMap) map)); @@ -194,8 +197,9 @@ private AColGroup recodeToDummy(ColumnEncoderComposite c) throws Exception { ADictionary d = new IdentityDictionary(colIndexes.size(), containsNull); AMapToData m = createMappingAMapToData(a, map, containsNull); - - return ColGroupDDC.create(colIndexes, d, m, null); + AColGroup ret = ColGroupDDC.create(colIndexes, d, m, null); + nnz.addAndGet(ret.getNumberNonZeros(in.getNumRows())); + return ret; } private AColGroup bin(ColumnEncoderComposite c) throws InterruptedException, ExecutionException { @@ -212,6 +216,7 @@ private AColGroup bin(ColumnEncoderComposite c) throws InterruptedException, Exe m = binEncode(a, b, containsNull); AColGroup ret = ColGroupDDC.create(colIndexes, d, m, null); + nnz.addAndGet(ret.getNumberNonZeros(in.getNumRows())); return ret; } @@ -248,25 +253,25 @@ private void BinEncodeParallel(Array a, ColumnEncoderBin b, boolean nulls, fi throws InterruptedException, ExecutionException { final List> tasks = new ArrayList<>(); final int blockSize = Math.max(ROW_PARALLELIZATION_THRESHOLD / 2, rlen + k / k); - final ExecutorService pool = CommonThreadPool.get(k); - try { - - for(int i = 0; i < rlen; i += blockSize) { - final int start = i; - final int end = Math.min(rlen, i + blockSize); - tasks.add(pool.submit(() -> { - if(nulls) - binEncodeWithNulls(a, b, m, start, 
end); - else - binEncodeNoNull(a, b, m, start, end); - })); - } - for(Future t : tasks) - t.get(); - } - finally { - pool.shutdown(); + // final ExecutorService pool = CommonThreadPool.get(k); + // try { + + for(int i = 0; i < rlen; i += blockSize) { + final int start = i; + final int end = Math.min(rlen, i + blockSize); + tasks.add(pool.submit(() -> { + if(nulls) + binEncodeWithNulls(a, b, m, start, end); + else + binEncodeNoNull(a, b, m, start, end); + })); } + for(Future t : tasks) + t.get(); + // } + // finally { + // pool.shutdown(); + // } } private void binEncodeWithNulls(Array a, ColumnEncoderBin b, AMapToData m, int l, int u) { @@ -328,6 +333,7 @@ private AColGroup binToDummy(ColumnEncoderComposite c) throws InterruptedExcepti final AMapToData m; m = binEncode(a, b, containsNull); AColGroup ret = ColGroupDDC.create(colIndexes, d, m, null); + nnz.addAndGet(ret.getNumberNonZeros(in.getNumRows())); return ret; } @@ -336,7 +342,7 @@ private AColGroup recode(ColumnEncoderComposite c) throws Exception { int colId = c._colID; Array a = (Array) in.getColumn(colId - 1); estimateRCDMapSize(c); - Map map = a.getRecodeMap(c._estNumDistincts, CommonThreadPool.get(k)); + Map map = a.getRecodeMap(c._estNumDistincts, pool); boolean containsNull = a.containsNull(); int domain = map.size(); @@ -356,8 +362,9 @@ private AColGroup recode(ColumnEncoderComposite c) throws Exception { List r = c.getEncoders(); r.set(0, new ColumnEncoderRecode(colId, (HashMap) map)); - return ColGroupDDC.create(colIndexes, d, m, null); - + AColGroup ret = ColGroupDDC.create(colIndexes, d, m, null); + nnz.addAndGet(ret.getNumberNonZeros(in.getNumRows())); + return ret; } @SuppressWarnings("unchecked") @@ -384,13 +391,15 @@ private AColGroup passThroughNormal(ColumnEncoderComposite c, final IColInde double[] vals = (double[]) a.changeType(ValueType.FP64).get(); MatrixBlock col = new MatrixBlock(a.size(), 1, vals); - col.recomputeNonZeros(1); + long nz = col.recomputeNonZeros(1); + + 
nnz.addAndGet(nz); return ColGroupUncompressed.create(colIndexes, col, false); } else { boolean containsNull = a.containsNull(); estimateRCDMapSize(c); - Map map = a.getRecodeMap(c._estNumDistincts, CommonThreadPool.get(k)); + Map map = a.getRecodeMap(c._estNumDistincts, pool); double[] vals = new double[map.size() + (containsNull ? 1 : 0)]; if(containsNull) vals[map.size()] = Double.NaN; @@ -398,7 +407,9 @@ private AColGroup passThroughNormal(ColumnEncoderComposite c, final IColInde map.forEach((k, v) -> vals[v.intValue() - 1] = UtilFunctions.objectToDouble(t, k)); ADictionary d = Dictionary.create(vals); AMapToData m = createMappingAMapToData(a, map, containsNull); - return ColGroupDDC.create(colIndexes, d, m, null); + AColGroup ret = ColGroupDDC.create(colIndexes, d, m, null); + nnz.addAndGet(ret.getNumberNonZeros(in.getNumRows())); + return ret; } } @@ -415,8 +426,10 @@ private AColGroup passThroughCompressed(final IColIndex colIndexes, final Ar vals[i] = dict.getAsDouble(i); ADictionary d = Dictionary.create(vals); + AColGroup ret = ColGroupDDC.create(colIndexes, d, aDDC.getMap(), null); - return ColGroupDDC.create(colIndexes, d, aDDC.getMap(), null); + nnz.addAndGet(ret.getNumberNonZeros(in.getNumRows())); + return ret; } private AMapToData createMappingAMapToData(Array a, Map map, boolean containsNull) @@ -439,31 +452,24 @@ private AMapToData CreateMappingParallel(Array a, Map map, bo final int blkz = Math.max(ROW_PARALLELIZATION_THRESHOLD / 2, (nRow + k) / k); List> tasks = new ArrayList<>(); - // make a thread local pool. 
- // this pool is independent of the potential generally shared pool - ExecutorService pool = CommonThreadPool.get(k); - try { - for(int i = 0; i < nRow; i += blkz) { - final int start = i; - final int end = Math.min(nRow, i + blkz); - tasks.add(pool.submit(() -> { - if(containsNull) - return createMappingAMapToDataWithNull(a, map, si, m, start, end); - else - return createMappingAMapToDataNoNull(a, map, m, start, end); + for(int i = 0; i < nRow; i += blkz) { + final int start = i; + final int end = Math.min(nRow, i + blkz); - })); - } + tasks.add(pool.submit(() -> { + if(containsNull) + return createMappingAMapToDataWithNull(a, map, si, m, start, end); + else + return createMappingAMapToDataNoNull(a, map, m, start, end); - for(Future t : tasks) - t.get(); - return m; - } - finally { - pool.shutdown(); + })); } + for(Future t : tasks) + t.get(); + return m; + } private AMapToData createMappingSingleThread(Array a, Map map, boolean containsNull, final int si, diff --git a/src/test/java/org/apache/sysds/performance/frame/Transform.java b/src/test/java/org/apache/sysds/performance/frame/Transform.java index 1105214f699..ab030121e29 100644 --- a/src/test/java/org/apache/sysds/performance/frame/Transform.java +++ b/src/test/java/org/apache/sysds/performance/frame/Transform.java @@ -49,10 +49,8 @@ public Transform(int N, IGenerate gen, int k, String spec) { } public void run() throws Exception { - execute(() -> te(), () -> clear(), "Normal"); - execute(() -> tec(), () -> clear(), "Compressed"); // execute(() -> te(), () -> clear(), "Normal"); - // execute(() -> tec(), () -> clear(), "Compressed"); + execute(() -> tec(), () -> clear(), "Compressed"); } private void te() { @@ -92,33 +90,33 @@ public static void main(String[] args) throws Exception { for(int i = 1; i < 1000; i *= 10) { - in = TestUtils.generateRandomFrameBlock(100000 * i, new ValueType[] {ValueType.UINT4}, 32); + // in = TestUtils.generateRandomFrameBlock(100000 * i, new ValueType[] {ValueType.UINT4}, 32); 
- System.out.println("Without null"); - run(k, in); + // System.out.println("Without null"); + // run(k, in); - System.out.println("Compressed without null"); - FrameLibCompress.compress(in, k); - run(k, in); + // System.out.println("Compressed without null"); + // in = FrameLibCompress.compress(in, k); + // run(k, in); - in = TestUtils.generateRandomFrameBlock(100000 * i, new ValueType[] {ValueType.UINT4}, 32, 0.5); + // in = TestUtils.generateRandomFrameBlock(100000 * i, new ValueType[] {ValueType.UINT4}, 32, 0.5); - System.out.println("With null"); + // System.out.println("With null"); - run(k, in); - System.out.println("Compressed with null"); - FrameLibCompress.compress(in, k); - run(k, in); + // run(k, in); + // System.out.println("Compressed with null"); + // in = FrameLibCompress.compress(in, k); + // run(k, in); in = TestUtils.generateRandomFrameBlock( 100000 * i, new ValueType[] {ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, ValueType.UINT4}, 32); - System.out.println("10 col without null"); - run10(k, in); + // System.out.println("10 col without null"); + // run10(k, in); System.out.println("10 col compressed without null"); - FrameLibCompress.compress(in, k); + in = FrameLibCompress.compress(in, k); run10(k, in); in = TestUtils.generateRandomFrameBlock( @@ -126,11 +124,11 @@ public static void main(String[] args) throws Exception { ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, ValueType.UINT4}, 32, 0.5); - System.out.println("10 col with null"); - run10(k, in); - System.out.println("10 col Compressed with null"); - FrameLibCompress.compress(in, k); - run10(k, in); + // System.out.println("10 col with null"); + // run10(k, in); + // System.out.println("10 col Compressed with null"); + // in = FrameLibCompress.compress(in, k); + // run10(k, in); } System.exit(0); // forcefully stop. 
@@ -139,45 +137,45 @@ public static void main(String[] args) throws Exception { private static void run10(int k, FrameBlock in) throws Exception { ConstFrame gen = new ConstFrame(in); new Transform(300, gen, k, "{}").run(); - new Transform(300, gen, k, "{ids:true, recode:[1,2,3,4,5,6,7,8,9,10]}").run(); - new Transform(300, gen, k, "{ids:true, bin:[" // - + "\n{id:1, method:equi-width, numbins:4}," // - + "\n{id:2, method:equi-width, numbins:4}," // - + "\n{id:3, method:equi-width, numbins:4}," // - + "\n{id:4, method:equi-width, numbins:4}," // - + "\n{id:5, method:equi-width, numbins:4}," // - + "\n{id:6, method:equi-width, numbins:4}," // - + "\n{id:7, method:equi-width, numbins:4}," // - + "\n{id:8, method:equi-width, numbins:4}," // - + "\n{id:9, method:equi-width, numbins:4}," // - + "\n{id:10, method:equi-width, numbins:4}," // - + "]}").run(); - new Transform(300, gen, k, "{ids:true, bin:[" // - + "\n{id:1, method:equi-width, numbins:4}," // - + "\n{id:2, method:equi-width, numbins:4}," // - + "\n{id:3, method:equi-width, numbins:4}," // - + "\n{id:4, method:equi-width, numbins:4}," // - + "\n{id:5, method:equi-width, numbins:4}," // - + "\n{id:6, method:equi-width, numbins:4}," // - + "\n{id:7, method:equi-width, numbins:4}," // - + "\n{id:8, method:equi-width, numbins:4}," // - + "\n{id:9, method:equi-width, numbins:4}," // - + "\n{id:10, method:equi-width, numbins:4}," // - + "], dummycode:[1,2,3,4,5,6,7,8,9,10]}").run(); - new Transform(300, gen, k, "{ids:true, hash:[1,2,3,4,5,6,7,8,9,10], K:10}").run(); - new Transform(300, gen, k, "{ids:true, hash:[1,2,3,4,5,6,7,8,9,10], K:10, dummycode:[1,2,3,4,5,6,7,8,9,10]}") - .run(); + // new Transform(300, gen, k, "{ids:true, recode:[1,2,3,4,5,6,7,8,9,10]}").run(); + // new Transform(300, gen, k, "{ids:true, bin:[" // + // + "\n{id:1, method:equi-width, numbins:4}," // + // + "\n{id:2, method:equi-width, numbins:4}," // + // + "\n{id:3, method:equi-width, numbins:4}," // + // + "\n{id:4, method:equi-width, 
numbins:4}," // + // + "\n{id:5, method:equi-width, numbins:4}," // + // + "\n{id:6, method:equi-width, numbins:4}," // + // + "\n{id:7, method:equi-width, numbins:4}," // + // + "\n{id:8, method:equi-width, numbins:4}," // + // + "\n{id:9, method:equi-width, numbins:4}," // + // + "\n{id:10, method:equi-width, numbins:4}," // + // + "]}").run(); + // new Transform(300, gen, k, "{ids:true, bin:[" // + // + "\n{id:1, method:equi-width, numbins:4}," // + // + "\n{id:2, method:equi-width, numbins:4}," // + // + "\n{id:3, method:equi-width, numbins:4}," // + // + "\n{id:4, method:equi-width, numbins:4}," // + // + "\n{id:5, method:equi-width, numbins:4}," // + // + "\n{id:6, method:equi-width, numbins:4}," // + // + "\n{id:7, method:equi-width, numbins:4}," // + // + "\n{id:8, method:equi-width, numbins:4}," // + // + "\n{id:9, method:equi-width, numbins:4}," // + // + "\n{id:10, method:equi-width, numbins:4}," // + // + "], dummycode:[1,2,3,4,5,6,7,8,9,10]}").run(); + // new Transform(300, gen, k, "{ids:true, hash:[1,2,3,4,5,6,7,8,9,10], K:10}").run(); + // new Transform(300, gen, k, "{ids:true, hash:[1,2,3,4,5,6,7,8,9,10], K:10, dummycode:[1,2,3,4,5,6,7,8,9,10]}") + // .run(); } private static void run(int k, FrameBlock in) throws Exception { ConstFrame gen = new ConstFrame(in); // // passthrough - new Transform(300, gen, k, "{}").run(); - new Transform(300, gen, k, "{ids:true, recode:[1]}").run(); - new Transform(300, gen, k, "{ids:true, bin:[{id:1, method:equi-width, numbins:4}]}").run(); - new Transform(300, gen, k, "{ids:true, bin:[{id:1, method:equi-width, numbins:4}], dummycode:[1]}").run(); - new Transform(300, gen, k, "{ids:true, hash:[1], K:10}").run(); - new Transform(300, gen, k, "{ids:true, hash:[1], K:10, dummycode:[1]}").run(); + // new Transform(300, gen, k, "{}").run(); + // new Transform(300, gen, k, "{ids:true, recode:[1]}").run(); + // new Transform(300, gen, k, "{ids:true, bin:[{id:1, method:equi-width, numbins:4}]}").run(); + // new 
Transform(300, gen, k, "{ids:true, bin:[{id:1, method:equi-width, numbins:4}], dummycode:[1]}").run(); + // new Transform(300, gen, k, "{ids:true, hash:[1], K:10}").run(); + // new Transform(300, gen, k, "{ids:true, hash:[1], K:10, dummycode:[1]}").run(); } } From 1a5b6e9bffafc51b3885dee7cf33cb74af8388e5 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 13 Jan 2025 14:03:02 +0100 Subject: [PATCH 08/81] count by 8 byte --- .../compress/colgroup/mapping/MapToByte.java | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToByte.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToByte.java index f8569be67b2..fea31022819 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToByte.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToByte.java @@ -166,11 +166,26 @@ public void copyBit(MapToBit d) { @Override public int[] getCounts(int[] ret) { - for(int i = 0; i < _data.length; i++) + final int h = (_data.length) % 8; + for(int i = 0; i < h; i++) ret[_data[i] & 0xFF]++; + getCountsBy8P(ret, h, _data.length); return ret; } + private void getCountsBy8P(int[] ret, int s, int e) { + for(int i = s; i < e; i += 8) { + ret[_data[i] & 0xFF]++; + ret[_data[i + 1] & 0xFF]++; + ret[_data[i + 2] & 0xFF]++; + ret[_data[i + 3] & 0xFF]++; + ret[_data[i + 4] & 0xFF]++; + ret[_data[i + 5] & 0xFF]++; + ret[_data[i + 6] & 0xFF]++; + ret[_data[i + 7] & 0xFF]++; + } + } + @Override protected void preAggregateDenseToRowBy8(double[] mV, double[] preAV, int cl, int cu, int off) { final int h = (cu - cl) % 8; From f9c14fba5458f762804c62584a02d56c8acd8ff8 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 13 Jan 2025 14:07:50 +0100 Subject: [PATCH 09/81] by8 count --- .../compress/colgroup/mapping/MapToChar.java | 18 ++++++++++++- .../colgroup/mapping/MapToCharPByte.java | 26 ++++++++++++++----- 
.../compress/colgroup/mapping/MapToInt.java | 18 ++++++++++++- .../compress/colgroup/mapping/MapToZero.java | 3 +-- 4 files changed, 54 insertions(+), 11 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToChar.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToChar.java index fb6317ec1a3..f1767dc06d8 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToChar.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToChar.java @@ -230,11 +230,27 @@ public void copyInt(int[] d, int start, int end) { @Override public int[] getCounts(int[] ret) { - for(int i = 0; i < _data.length; i++) + final int h = (_data.length) % 8; + for(int i = 0; i < h; i++) ret[_data[i]]++; + getCountsBy8P(ret, h, _data.length); return ret; } + private void getCountsBy8P(int[] ret, int s, int e) { + for(int i = s; i < e; i += 8) { + ret[_data[i]]++; + ret[_data[i + 1]]++; + ret[_data[i + 2]]++; + ret[_data[i + 3]]++; + ret[_data[i + 4]]++; + ret[_data[i + 5]]++; + ret[_data[i + 6]]++; + ret[_data[i + 7]]++; + } + } + + @Override public AMapToData resize(int unique) { final int size = _data.length; diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToCharPByte.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToCharPByte.java index bd66667d6a9..8c6fb6366ea 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToCharPByte.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToCharPByte.java @@ -37,7 +37,7 @@ public class MapToCharPByte extends AMapToData { private static final long serialVersionUID = 6315708056775476541L; - public static final int max = (0xFFFF + 1) * 128 -1; + public static final int max = (0xFFFF + 1) * 128 - 1; private final char[] _data_c; private final byte[] _data_b; // next byte after the char @@ -100,7 +100,7 @@ public void set(int n, int v) { } 
@Override - public void set(int l, int u, int off, AMapToData tm){ + public void set(int l, int u, int off, AMapToData tm) { for(int i = l; i < u; i++, off++) { set(i, tm.getIndex(off)); } @@ -169,14 +169,28 @@ public void copyInt(int[] d, int start, int end) { set(i, d[i]); } - @Override public int[] getCounts(int[] ret) { - for(int i = 0; i < size(); i++) - ret[getIndex(i)]++; + final int h = (size()) % 8; + for(int i = 0; i < h; i++) + ret[_data_c[i] + ((int) _data_b[i] << 16)]++; + getCountsBy8P(ret, h, size()); return ret; } + private void getCountsBy8P(int[] ret, int s, int e) { + for(int i = s; i < e; i += 8) { + ret[getIndex(i)]++; + ret[getIndex(i + 1)]++; + ret[getIndex(i + 2)]++; + ret[getIndex(i + 3)]++; + ret[getIndex(i + 4)]++; + ret[getIndex(i + 5)]++; + ret[getIndex(i + 6)]++; + ret[getIndex(i + 7)]++; + } + } + @Override public AMapToData resize(int unique) { final int size = _data_c.length; @@ -263,7 +277,6 @@ public AMapToData appendN(IMapToDataGroup[] d) { } - @Override public boolean equals(AMapToData e) { return e instanceof MapToCharPByte && // @@ -326,7 +339,6 @@ protected void decompressToRangeNoOffBy8(double[] c, int r, double[] values) { c[r + 7] += values[getIndex(r + 7)]; } - @Override public void decompressToRange(double[] c, int rl, int ru, int offR, double[] values) { // OVERWRITTEN FOR JIT COMPILE! 
diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToInt.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToInt.java index 3dcec05e373..8997b643b3a 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToInt.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToInt.java @@ -195,11 +195,27 @@ public void copyInt(int[] d, int start, int end) { @Override public int[] getCounts(int[] ret) { - for(int i = 0; i < _data.length; i++) + final int h = (_data.length) % 8; + for(int i = 0; i < h; i++) ret[_data[i]]++; + getCountsBy8P(ret, h, _data.length); return ret; } + private void getCountsBy8P(int[] ret, int s, int e) { + for(int i = s; i < e; i += 8) { + ret[_data[i]]++; + ret[_data[i + 1]]++; + ret[_data[i + 2]]++; + ret[_data[i + 3]]++; + ret[_data[i + 4]]++; + ret[_data[i + 5]]++; + ret[_data[i + 6]]++; + ret[_data[i + 7]]++; + } + } + + @Override public int countRuns() { int c = 1; diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToZero.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToZero.java index b839fc336c2..fa751bd5057 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToZero.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToZero.java @@ -110,8 +110,7 @@ public int getUpperBoundValue() { @Override public int[] getCounts(int[] ret) { - final int sz = size(); - ret[0] = sz; + ret[0] = size(); return ret; } From b33dd7e39c969dd7cd30b188db05a7d97610f7be Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 13 Jan 2025 14:41:23 +0100 Subject: [PATCH 10/81] safety tk settings --- .../runtime/frame/data/columns/Array.java | 7 +++--- .../transform/encode/CompressedEncode.java | 23 ++++++++----------- .../component/frame/array/RecodeMapTest.java | 6 ++--- 3 files changed, 17 insertions(+), 19 deletions(-) diff --git 
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java index 0757d24b525..2cfbfcdebff 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java @@ -108,7 +108,7 @@ public synchronized final Map getRecodeMap() { */ public synchronized final Map getRecodeMap(int estimate) { try { - return getRecodeMap(estimate, null); + return getRecodeMap(estimate, null, -1); } catch(Exception e) { throw new RuntimeException(e); @@ -122,11 +122,12 @@ public synchronized final Map getRecodeMap(int estimate) { * * @param estimate the estimated number of unique values in this array. * @param pool An executor pool to be used for parallel execution (Note this method does not shutdown the pool) + * @param k Parallelization degree allowed * @return A recode map * @throws ExecutionException if the parallel execution fails * @throws InterruptedException if the parallel execution fails */ - public synchronized final Map getRecodeMap(int estimate, ExecutorService pool) + public synchronized final Map getRecodeMap(int estimate, ExecutorService pool, int k) throws InterruptedException, ExecutionException { // probe cache for existing map Map map; @@ -162,7 +163,7 @@ protected Map createRecodeMap(int estimate, ExecutorService pool) final int s = size(); int k = OptimizerUtils.getTransformNumThreads(); Map ret; - if(pool == null || s < ROW_PARALLELIZATION_THRESHOLD) + if(k <= 1 || pool == null || s < ROW_PARALLELIZATION_THRESHOLD) ret = createRecodeMap(estimate, 0, s); else ret = parallelCreateRecodeMap(estimate, pool, s, k); diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java index c1d4a0f7749..fa448f6adfc 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java 
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java @@ -184,7 +184,7 @@ private AColGroup recodeToDummy(ColumnEncoderComposite c) throws Exception { Array a = (Array) in.getColumn(colId - 1); boolean containsNull = a.containsNull(); estimateRCDMapSize(c); - Map map = a.getRecodeMap(c._estNumDistincts, pool); + Map map = a.getRecodeMap(c._estNumDistincts, pool, k / in.getNumColumns() ); List r = c.getEncoders(); r.set(0, new ColumnEncoderRecode(colId, (HashMap) map)); @@ -236,7 +236,7 @@ private AMapToData binEncode(Array a, ColumnEncoderBin b, boolean nulls) } final int rlen = a.size(); - if(k > 1 && rlen > ROW_PARALLELIZATION_THRESHOLD) { + if(k / in.getNumColumns() > 1 && rlen > ROW_PARALLELIZATION_THRESHOLD) { BinEncodeParallel(a, b, nulls, m, rlen); } else { @@ -252,9 +252,8 @@ private AMapToData binEncode(Array a, ColumnEncoderBin b, boolean nulls) private void BinEncodeParallel(Array a, ColumnEncoderBin b, boolean nulls, final AMapToData m, final int rlen) throws InterruptedException, ExecutionException { final List> tasks = new ArrayList<>(); - final int blockSize = Math.max(ROW_PARALLELIZATION_THRESHOLD / 2, rlen + k / k); - // final ExecutorService pool = CommonThreadPool.get(k); - // try { + final int tk = k / in.getNumColumns(); + final int blockSize = Math.max(ROW_PARALLELIZATION_THRESHOLD / 2, rlen + tk / tk); for(int i = 0; i < rlen; i += blockSize) { final int start = i; @@ -268,10 +267,7 @@ private void BinEncodeParallel(Array a, ColumnEncoderBin b, boolean nulls, fi } for(Future t : tasks) t.get(); - // } - // finally { - // pool.shutdown(); - // } + } private void binEncodeWithNulls(Array a, ColumnEncoderBin b, AMapToData m, int l, int u) { @@ -342,7 +338,7 @@ private AColGroup recode(ColumnEncoderComposite c) throws Exception { int colId = c._colID; Array a = (Array) in.getColumn(colId - 1); estimateRCDMapSize(c); - Map map = a.getRecodeMap(c._estNumDistincts, pool); + Map map = a.getRecodeMap(c._estNumDistincts, 
pool, k / in.getNumColumns() ); boolean containsNull = a.containsNull(); int domain = map.size(); @@ -399,7 +395,7 @@ private AColGroup passThroughNormal(ColumnEncoderComposite c, final IColInde else { boolean containsNull = a.containsNull(); estimateRCDMapSize(c); - Map map = a.getRecodeMap(c._estNumDistincts, pool); + Map map = a.getRecodeMap(c._estNumDistincts, pool, k / in.getNumColumns() ); double[] vals = new double[map.size() + (containsNull ? 1 : 0)]; if(containsNull) vals[map.size()] = Double.NaN; @@ -441,7 +437,7 @@ private AMapToData createMappingAMapToData(Array a, Map map, final AMapToData m = MapToFactory.create(nRow, si + (containsNull ? 1 : 0)); - if(k > 1 && nRow > ROW_PARALLELIZATION_THRESHOLD) + if(k / in.getNumColumns() > 1 && nRow > ROW_PARALLELIZATION_THRESHOLD) return CreateMappingParallel(a, map, containsNull, si, nRow, m); else return createMappingSingleThread(a, map, containsNull, si, nRow, m); @@ -449,7 +445,8 @@ private AMapToData createMappingAMapToData(Array a, Map map, private AMapToData CreateMappingParallel(Array a, Map map, boolean containsNull, final int si, final int nRow, final AMapToData m) throws InterruptedException, ExecutionException { - final int blkz = Math.max(ROW_PARALLELIZATION_THRESHOLD / 2, (nRow + k) / k); + final int tk = k / in.getNumColumns(); + final int blkz = Math.max(ROW_PARALLELIZATION_THRESHOLD / 2, (nRow + tk) / tk); List> tasks = new ArrayList<>(); diff --git a/src/test/java/org/apache/sysds/test/component/frame/array/RecodeMapTest.java b/src/test/java/org/apache/sysds/test/component/frame/array/RecodeMapTest.java index 5cac2f22201..1d8f8fdd7c2 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/array/RecodeMapTest.java +++ b/src/test/java/org/apache/sysds/test/component/frame/array/RecodeMapTest.java @@ -64,7 +64,7 @@ public void createRecodeMapParallel() throws Exception { Logger.getLogger(Array.class).setLevel(Level.DEBUG); Array a = 
ArrayFactory.create(FrameArrayTests.generateRandomStringNUnique(1000, 324, 10)); - Map rcm = a.getRecodeMap(10, CommonThreadPool.get(10)); + Map rcm = a.getRecodeMap(10, CommonThreadPool.get(10), 10); assertTrue(rcm.size() == 10); final List log = LoggingUtils.reinsert(appender); assertTrue(log.size() >= 1); @@ -85,8 +85,8 @@ public void createRecodeMapParallel2() throws Exception { Logger.getLogger(Array.class).setLevel(Level.DEBUG); Array a = ArrayFactory.create(FrameArrayTests.generateRandomStringNUnique(1000, 324, 500)); - Map rcm = a.getRecodeMap(10, CommonThreadPool.get(10)); - Map rcm2 = a.getRecodeMap(10, null); + Map rcm = a.getRecodeMap(10, CommonThreadPool.get(10), 10); + Map rcm2 = a.getRecodeMap(10, null, -1); assertTrue(Math.abs(rcm.size() - 500) < 100); assertTrue(rcm.size() == rcm2.size()); From 59ca2426b844cfe050b333f78501557fa1268e93 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 13 Jan 2025 15:08:39 +0100 Subject: [PATCH 11/81] fix hash nnz --- .../frame/data/columns/ABooleanArray.java | 2 +- .../runtime/frame/data/columns/Array.java | 23 ++++++++----------- .../runtime/frame/data/columns/DDCArray.java | 4 ++-- .../frame/data/columns/StringArray.java | 4 ++-- .../transform/encode/CompressedEncode.java | 5 +++- .../sysds/performance/frame/Transform.java | 8 +++---- 6 files changed, 23 insertions(+), 23 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/ABooleanArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/ABooleanArray.java index 848bc38796b..09128ccaa37 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/ABooleanArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/ABooleanArray.java @@ -59,7 +59,7 @@ public boolean possiblyContainsNaN() { @Override - protected Map createRecodeMap(int estimate, ExecutorService pool) { + protected Map createRecodeMap(int estimate, ExecutorService pool, int k) { Map map = new HashMap<>(); int id = 1; 
for(int i = 0; i < size() && id <= 2; i++) { diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java index 2cfbfcdebff..ebd186e782d 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java @@ -34,7 +34,6 @@ import org.apache.commons.logging.LogFactory; import org.apache.hadoop.io.Writable; import org.apache.sysds.common.Types.ValueType; -import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData; import org.apache.sysds.runtime.compress.estim.sample.SampleEstimatorFactory; @@ -137,7 +136,7 @@ public synchronized final Map getRecodeMap(int estimate, ExecutorSer return map; // construct recode map - map = createRecodeMap(estimate, pool); + map = createRecodeMap(estimate, pool, k); // put created map into cache setCache(new SoftReference<>(map)); @@ -153,22 +152,23 @@ public synchronized final Map getRecodeMap(int estimate, ExecutorSer * @param estimate The estimate number of unique values inside this array. * @param pool The thread pool to use for parallel creation of recode map (can be null). (Note this method does * not shutdown the pool) + * @param k The allowed degree of parallelism * @return The recode map created. * @throws ExecutionException if the parallel execution fails * @throws InterruptedException if the parallel execution fails */ - protected Map createRecodeMap(int estimate, ExecutorService pool) + protected Map createRecodeMap(int estimate, ExecutorService pool, int k) throws InterruptedException, ExecutionException { - Timing t = new Timing(); + final boolean debug = LOG.isDebugEnabled(); + final Timing t = debug ? 
new Timing() : null; final int s = size(); - int k = OptimizerUtils.getTransformNumThreads(); - Map ret; + final Map ret; if(k <= 1 || pool == null || s < ROW_PARALLELIZATION_THRESHOLD) ret = createRecodeMap(estimate, 0, s); else ret = parallelCreateRecodeMap(estimate, pool, s, k); - if(LOG.isDebugEnabled()) { + if(debug) { String base = "CreateRecodeMap estimate: %10d actual %10d time: %10.5f"; LOG.debug(String.format(base, estimate, ret.size(), t.stop())); } @@ -230,12 +230,9 @@ private Map createRecodeMap(Map map, final int s, final } protected int addValRecodeMap(Map map, int id, int i) { - T val = getInternal(i); - if(val != null) { - Integer v = map.putIfAbsent(val, id); - if(v == null) - id++; - } + final T val = getInternal(i); + if(val != null && map.putIfAbsent(val, id) == null) + id++; return id; } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java index ecd827070b3..40013829637 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java @@ -178,9 +178,9 @@ public static Array compressToDDC(Array arr, int estimateUnique) { } @Override - protected Map createRecodeMap(int estimate, ExecutorService pool) + protected Map createRecodeMap(int estimate, ExecutorService pool, int k) throws InterruptedException, ExecutionException { - return dict.createRecodeMap(estimate, pool); + return dict.createRecodeMap(estimate, pool, k); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java index f1ef7943498..601fc78b2fc 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java @@ -669,7 +669,7 @@ public final boolean isNotEmpty(int i) 
{ } @Override - protected Map createRecodeMap(int estimate, ExecutorService pool) throws InterruptedException, ExecutionException { + protected Map createRecodeMap(int estimate, ExecutorService pool, int k) throws InterruptedException, ExecutionException { try { Map map = new HashMap<>((int) Math.min((long) estimate * 2, size())); for(int i = 0; i < size(); i++) { @@ -682,7 +682,7 @@ protected Map createRecodeMap(int estimate, ExecutorService poo return map; } catch(Exception e) { - return super.createRecodeMap(estimate, pool); + return super.createRecodeMap(estimate, pool, k); } } diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java index fa448f6adfc..5ade585dbb6 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java @@ -531,6 +531,7 @@ private AColGroup hash(ColumnEncoderComposite c) { AMapToData m = createHashMappingAMapToData(a, domain, nulls); AColGroup ret = ColGroupDDC.create(colIndexes, d, m, null); + nnz.addAndGet(ret.getNumberNonZeros(in.getNumRows())); return ret; } @@ -545,7 +546,9 @@ private AColGroup hashToDummy(ColumnEncoderComposite c) { return ColGroupConst.create(colIndexes, new double[] {1}); ADictionary d = new IdentityDictionary(colIndexes.size(), nulls); AMapToData m = createHashMappingAMapToData(a, domain, nulls); - return ColGroupDDC.create(colIndexes, d, m, null); + AColGroup ret = ColGroupDDC.create(colIndexes, d, m, null); + nnz.addAndGet(ret.getNumberNonZeros(in.getNumRows())); + return ret; } @SuppressWarnings("unchecked") diff --git a/src/test/java/org/apache/sysds/performance/frame/Transform.java b/src/test/java/org/apache/sysds/performance/frame/Transform.java index ab030121e29..c6159a1a329 100644 --- a/src/test/java/org/apache/sysds/performance/frame/Transform.java +++ 
b/src/test/java/org/apache/sysds/performance/frame/Transform.java @@ -113,11 +113,11 @@ public static void main(String[] args) throws Exception { ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, ValueType.UINT4}, 32); - // System.out.println("10 col without null"); - // run10(k, in); - System.out.println("10 col compressed without null"); - in = FrameLibCompress.compress(in, k); + System.out.println("10 col without null"); run10(k, in); + // System.out.println("10 col compressed without null"); + // in = FrameLibCompress.compress(in, k); + // run10(k, in); in = TestUtils.generateRandomFrameBlock( 100000 * i, new ValueType[] {ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, From 045880888539286497a478d9f12d356fa648c0c8 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 13 Jan 2025 18:23:24 +0100 Subject: [PATCH 12/81] HashMapToInt --- .../compress/colgroup/mapping/MapToChar.java | 18 +- .../frame/data/columns/ABooleanArray.java | 10 +- .../runtime/frame/data/columns/Array.java | 66 ++--- .../runtime/frame/data/columns/DDCArray.java | 2 +- .../frame/data/columns/HashIntegerArray.java | 7 +- .../frame/data/columns/HashLongArray.java | 15 +- .../frame/data/columns/HashMapToInt.java | 243 ++++++++++++++++++ .../frame/data/columns/IntegerArray.java | 7 + .../frame/data/columns/OptionalArray.java | 7 +- .../frame/data/columns/StringArray.java | 6 +- .../transform/encode/ColumnEncoderRecode.java | 25 +- .../transform/encode/CompressedEncode.java | 66 +++-- .../transform/encode/MultiColumnEncoder.java | 4 +- .../sysds/performance/frame/Transform.java | 3 +- 14 files changed, 374 insertions(+), 105 deletions(-) create mode 100644 src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToChar.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToChar.java index f1767dc06d8..5982157726d 
100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToChar.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToChar.java @@ -232,21 +232,21 @@ public void copyInt(int[] d, int start, int end) { public int[] getCounts(int[] ret) { final int h = (_data.length) % 8; for(int i = 0; i < h; i++) - ret[_data[i]]++; + ret[getIndex(i)]++; getCountsBy8P(ret, h, _data.length); return ret; } private void getCountsBy8P(int[] ret, int s, int e) { for(int i = s; i < e; i += 8) { - ret[_data[i]]++; - ret[_data[i + 1]]++; - ret[_data[i + 2]]++; - ret[_data[i + 3]]++; - ret[_data[i + 4]]++; - ret[_data[i + 5]]++; - ret[_data[i + 6]]++; - ret[_data[i + 7]]++; + ret[getIndex(i)]++; + ret[getIndex(i + 1)]++; + ret[getIndex(i + 2)]++; + ret[getIndex(i + 3)]++; + ret[getIndex(i + 4)]++; + ret[getIndex(i + 5)]++; + ret[getIndex(i + 6)]++; + ret[getIndex(i + 7)]++; } } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/ABooleanArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/ABooleanArray.java index 09128ccaa37..97baf0d7a5b 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/ABooleanArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/ABooleanArray.java @@ -19,8 +19,6 @@ package org.apache.sysds.runtime.frame.data.columns; -import java.util.HashMap; -import java.util.Map; import java.util.concurrent.ExecutorService; public abstract class ABooleanArray extends Array { @@ -59,12 +57,12 @@ public boolean possiblyContainsNaN() { @Override - protected Map createRecodeMap(int estimate, ExecutorService pool, int k) { - Map map = new HashMap<>(); + protected HashMapToInt createRecodeMap(int estimate, ExecutorService pool, int k) { + HashMapToInt map = new HashMapToInt(2); int id = 1; for(int i = 0; i < size() && id <= 2; i++) { - Integer v = map.putIfAbsent(get(i), id); - if(v == null) + int v = map.putIfAbsentI(get(i), id); + if(v == -1) id++; } return map; 
diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java index ebd186e782d..d34b2ad6153 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java @@ -53,7 +53,7 @@ public abstract class Array implements Writable { public static int ROW_PARALLELIZATION_THRESHOLD = 10000; /** A soft reference to a memorization of this arrays mapping, used in transformEncode */ - protected SoftReference> _rcdMapCache = null; + protected SoftReference> _rcdMapCache = null; /** The current allocated number of elements in this Array */ protected int _size; @@ -73,7 +73,7 @@ protected int newSize() { * * @return The cached recode map */ - public final SoftReference> getCache() { + public final SoftReference> getCache() { return _rcdMapCache; } @@ -82,7 +82,7 @@ public final SoftReference> getCache() { * * @param m The element to cache. */ - public final void setCache(SoftReference> m) { + public final void setCache(SoftReference> m) { _rcdMapCache = m; } @@ -121,16 +121,16 @@ public synchronized final Map getRecodeMap(int estimate) { * * @param estimate the estimated number of unique values in this array. 
* @param pool An executor pool to be used for parallel execution (Note this method does not shutdown the pool) - * @param k Parallelization degree allowed + * @param k Parallelization degree allowed * @return A recode map * @throws ExecutionException if the parallel execution fails * @throws InterruptedException if the parallel execution fails */ - public synchronized final Map getRecodeMap(int estimate, ExecutorService pool, int k) + public synchronized final Map getRecodeMap(int estimate, ExecutorService pool, int k) throws InterruptedException, ExecutionException { // probe cache for existing map - Map map; - SoftReference> tmp = getCache(); + Map map; + SoftReference> tmp = getCache(); map = (tmp != null) ? tmp.get() : null; if(map != null) return map; @@ -152,17 +152,17 @@ public synchronized final Map getRecodeMap(int estimate, ExecutorSer * @param estimate The estimate number of unique values inside this array. * @param pool The thread pool to use for parallel creation of recode map (can be null). (Note this method does * not shutdown the pool) - * @param k The allowed degree of parallelism + * @param k The allowed degree of parallelism * @return The recode map created. * @throws ExecutionException if the parallel execution fails * @throws InterruptedException if the parallel execution fails */ - protected Map createRecodeMap(int estimate, ExecutorService pool, int k) + protected HashMapToInt createRecodeMap(int estimate, ExecutorService pool, int k) throws InterruptedException, ExecutionException { final boolean debug = LOG.isDebugEnabled(); final Timing t = debug ? 
new Timing() : null; final int s = size(); - final Map ret; + final HashMapToInt ret; if(k <= 1 || pool == null || s < ROW_PARALLELIZATION_THRESHOLD) ret = createRecodeMap(estimate, 0, s); else @@ -175,21 +175,21 @@ protected Map createRecodeMap(int estimate, ExecutorService pool, in return ret; } - private Map parallelCreateRecodeMap(int estimate, ExecutorService pool, final int s, int k) + private HashMapToInt parallelCreateRecodeMap(int estimate, ExecutorService pool, final int s, int k) throws InterruptedException, ExecutionException { final int blk = Math.max(ROW_PARALLELIZATION_THRESHOLD / 2, (s + k) / k); - final List>> tasks = new ArrayList<>(); + final List>> tasks = new ArrayList<>(); for(int i = blk; i < s; i += blk) { // start at blk for the other threads final int start = i; final int end = Math.min(i + blk, s); tasks.add(pool.submit(() -> createRecodeMap(estimate, start, end))); } // make the initial map thread local allocation. - final Map map = new HashMap<>((int) (estimate * 1.3)); + final HashMapToInt map = new HashMapToInt((int) (estimate * 1.3)); createRecodeMap(map, 0, blk); for(int i = 0; i < tasks.size(); i++) { // merge with other threads work. 
- final Map map2 = tasks.get(i).get(); + final HashMapToInt map2 = tasks.get(i).get(); mergeRecodeMaps(map, map2); } return map; @@ -216,22 +216,22 @@ protected static void mergeRecodeMaps(Map target, Map createRecodeMap(final int estimate, final int s, final int e) { + protected HashMapToInt createRecodeMap(final int estimate, final int s, final int e) { // * 1.3 because we hashMap has a load factor of 1.75 - final Map map = new HashMap<>((int) (Math.min((long) estimate, (e - s)) * 1.3)); + final HashMapToInt map = new HashMapToInt<>((int) (Math.min((long) estimate, (e - s)) * 1.3)); return createRecodeMap(map, s, e); } - private Map createRecodeMap(Map map, final int s, final int e) { + protected HashMapToInt createRecodeMap(HashMapToInt map, final int s, final int e) { int id = 1; for(int i = s; i < e; i++) id = addValRecodeMap(map, id, i); return map; } - protected int addValRecodeMap(Map map, int id, int i) { + protected int addValRecodeMap(HashMapToInt map, int id, int i) { final T val = getInternal(i); - if(val != null && map.putIfAbsent(val, id) == null) + if(val != null && map.putIfAbsentI(val, id) == -1) id++; return id; } @@ -1040,8 +1040,8 @@ public double[] minMax(int l, int u) { * @param m The MapToData to set the value part of the Map from * @param i The index to set in m */ - public void setM(Map map, AMapToData m, int i) { - m.set(i, map.get(getInternal(i)).intValue() - 1); + public void setM(HashMapToInt map, AMapToData m, int i) { + m.set(i, map.getI(getInternal(i)) - 1); } /** @@ -1053,17 +1053,17 @@ public void setM(Map map, AMapToData m, int i) { * @param m The MapToData to set the value part of the Map from * @param i The index to set in m */ - public void setM(Map map, int si, AMapToData m, int i) { - try { - final T v = getInternal(i); - if(v != null) - m.set(i, map.get(v).intValue() - 1); - else - m.set(i, si); - } - catch(Exception e) { - String error = "expected: " + getInternal(i) + " to be in map: " + map; - throw new 
RuntimeException(error, e); - } + public void setM(HashMapToInt map, int si, AMapToData m, int i) { + // try { + final T v = getInternal(i); + if(v != null) + m.set(i, map.getI(v) - 1); + else + m.set(i, si); + // } + // catch(Exception e) { + // String error = "expected: " + getInternal(i) + " to be in map: " + map; + // throw new RuntimeException(error, e); + // } } } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java index 40013829637..69bfa38e7ff 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java @@ -178,7 +178,7 @@ public static Array compressToDDC(Array arr, int estimateUnique) { } @Override - protected Map createRecodeMap(int estimate, ExecutorService pool, int k) + protected HashMapToInt createRecodeMap(int estimate, ExecutorService pool, int k) throws InterruptedException, ExecutionException { return dict.createRecodeMap(estimate, pool, k); } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashIntegerArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashIntegerArray.java index 328c9a565fe..059b53e9d9e 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashIntegerArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashIntegerArray.java @@ -23,7 +23,6 @@ import java.io.DataOutput; import java.io.IOException; import java.util.Arrays; -import java.util.Map; import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.common.Types.ValueType; @@ -434,7 +433,7 @@ public boolean possiblyContainsNaN() { } @Override - protected int addValRecodeMap(Map map, int id, int i) { + protected int addValRecodeMap(HashMapToInt map, int id, int i) { Integer val = Integer.valueOf(getInt(i)); Integer v = map.putIfAbsent(val, id); if(v == null) @@ -443,12 
+442,12 @@ protected int addValRecodeMap(Map map, int id, int i) { } @Override - public void setM(Map map, AMapToData m, int i) { + public void setM(HashMapToInt map, AMapToData m, int i) { m.set(i, map.get(Integer.valueOf(getInt(i))).intValue() - 1); } @Override - public void setM(Map map, int si, AMapToData m, int i) { + public void setM(HashMapToInt map, int si, AMapToData m, int i) { final Integer v = Integer.valueOf(getInt(i)); m.set(i, map.get(v).intValue() - 1); } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashLongArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashLongArray.java index 8fd308951e4..7d9448fc9c5 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashLongArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashLongArray.java @@ -23,7 +23,6 @@ import java.io.DataOutput; import java.io.IOException; import java.util.Arrays; -import java.util.Map; import org.apache.commons.lang3.NotImplementedException; import org.apache.sysds.common.Types.ValueType; @@ -430,22 +429,12 @@ public boolean possiblyContainsNaN() { } @Override - protected int addValRecodeMap(Map map, int id, int i) { - Long val = Long.valueOf(getLong(i)); - Integer v = map.putIfAbsent(val, id); - if(v == null) - id++; - - return id; - } - - @Override - public void setM(Map map, AMapToData m, int i) { + public void setM(HashMapToInt map, AMapToData m, int i) { m.set(i, map.get(Long.valueOf(getLong(i))) - 1); } @Override - public void setM(Map map, int si, AMapToData m, int i) { + public void setM(HashMapToInt map, int si, AMapToData m, int i) { m.set(i, map.get(Long.valueOf(getLong(i))) - 1); } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java new file mode 100644 index 00000000000..bbc92ec7a8e --- /dev/null +++ 
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java @@ -0,0 +1,243 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.frame.data.columns; + +import java.io.Serializable; +import java.util.Collection; +import java.util.Map; +import java.util.Set; +import java.util.function.BiConsumer; + +public class HashMapToInt implements Map, Serializable, Cloneable { + + private static final long serialVersionUID = 3624988207265L; + static final int DEFAULT_INITIAL_CAPACITY = 1 << 4; + static final int MAXIMUM_CAPACITY = 1 << 30; + static final float DEFAULT_LOAD_FACTOR = 0.75f; + + + static class Node { + final K key; + int value; + Node next; + + Node( K key, int value, Node next) { + this.key = key; + this.value = value; + this.next = next; + } + + public final void setNext(Node n) { + next = n; + } + } + + protected Node[] buckets; + int size; + // protected List> keys; + // protected int[][] values; + + public HashMapToInt(int capacity) { + alloc(Math.max(capacity, 16)); + } + + + + + @SuppressWarnings({"unchecked"}) + protected void alloc(int size) { + Node[] tmp = (Node[])new Node[size]; + buckets = tmp; + } + + @Override + public int size() { + return 
size; + } + + @Override + public boolean isEmpty() { + return size == 0; + } + + @Override + @SuppressWarnings({"unchecked"}) + public boolean containsKey(Object key) { + return get((K) key) != -1; + } + + @Override + public boolean containsValue(Object value) { + throw new UnsupportedOperationException("Unimplemented method 'containsValue'"); + } + + @Override + @SuppressWarnings({"unchecked"}) + public Integer get(Object key) { + final int i = getI((K) key); + if(i != -1) + return i; + else + return null; + } + + public int getI(K key) { + final int ix = hash(key); + Node b = buckets[ix]; + if(b != null) { + do{ + if(b.key.equals(key)) + return b.value; + } while((b = b.next) != null); + } + return -1; + } + + public int hash(K key){ + return Math.abs(key.hashCode()) % buckets.length; + } + + @Override + public Integer put(K key, Integer value) { + int i = putI(key, value); + if(i != -1) + return i; + else + return null; + } + + @Override + public Integer putIfAbsent(K key, Integer value){ + int i = putIfAbsentI(key, value); + if(i != -1) + return i; + else + return null; + } + + public int putIfAbsentI(K key, int value){ + final int ix = hash(key); + Node b = buckets[ix]; + if( b == null) + return createBucket(ix, key, value); + else + return putIfAbsentBucket(ix, key, value); + } + + + private int putIfAbsentBucket(int ix, K key, int value) { + Node b = buckets[ix]; + while(true){ + if(b.key.equals(key)) + return b.value; + if(b.next == null){ + b.next = new Node<>(key, value, null); + size++; + return -1; + } + b = b.next; + } + } + + public int putI(K key, int value) { + final int ix = hash(key); + Node b = buckets[ix]; + if(b == null) + return createBucket(ix, key, value); + else + return addToBucket(ix, key, value); + } + + private int createBucket(int ix, K key, int value) { + buckets[ix] = new Node(key, value, null ); + size++; + return -1; + } + + private int addToBucket(int ix, K key, int value) { + Node b = buckets[ix]; + while(true){ + + 
if(b.key.equals(key)){ + int tmp = b.value; + b.value = value; + return tmp; + } + if(b.next == null){ + b.next = new Node<>(key, value, null); + size++; + return -1; + } + b = b.next; + } + } + + @Override + public Integer remove(Object key) { + throw new UnsupportedOperationException("Unimplemented method 'remove'"); + } + + @Override + public void putAll(Map m) { + throw new UnsupportedOperationException("Unimplemented method 'putAll'"); + } + + @Override + public void clear() { + throw new UnsupportedOperationException("Unimplemented method 'clear'"); + } + + @Override + public Set keySet() { + throw new UnsupportedOperationException("Unimplemented method 'keySet'"); + } + + @Override + public Collection values() { + throw new UnsupportedOperationException("Unimplemented method 'values'"); + } + + @Override + public Set> entrySet() { + throw new UnsupportedOperationException("Unimplemented method 'entrySet'"); + } + + @Override + public void forEach(BiConsumer action) { + for(Node n : buckets){ + if(n != null){ + do{ + action.accept(n.key, n.value); + } + while((n = n.next) != null); + } + } + } + + @Override + public String toString(){ + StringBuilder sb = new StringBuilder(); + this.forEach((k,v) -> { + sb.append("("+k +"→" + v+")"); + }); + return sb.toString(); + } + +} diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/IntegerArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/IntegerArray.java index 7a698dbd72f..bb3ad639327 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/IntegerArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/IntegerArray.java @@ -378,6 +378,13 @@ public boolean possiblyContainsNaN() { return false; } + @Override + protected int addValRecodeMap(HashMapToInt map, int id, int i) { + if( map.putIfAbsentI(_data[i], id) == -1) + id++; + return id; + } + @Override public String toString() { StringBuilder sb = new StringBuilder(_size * 5 + 2); diff --git 
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/OptionalArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/OptionalArray.java index dd0fca6cdfb..91cc5f2bead 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/OptionalArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/OptionalArray.java @@ -22,7 +22,6 @@ import java.io.DataInput; import java.io.DataOutput; import java.io.IOException; -import java.util.Map; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.DMLRuntimeException; @@ -468,12 +467,12 @@ public boolean possiblyContainsNaN() { } @Override - public void setM(Map map, AMapToData m, int i) { + public void setM(HashMapToInt map, AMapToData m, int i) { _a.setM(map, m, i); } @Override - public void setM(Map map, int si, AMapToData m, int i) { + public void setM(HashMapToInt map, int si, AMapToData m, int i) { if(_n.get(i)) _a.setM(map, si, m, i); else @@ -481,7 +480,7 @@ public void setM(Map map, int si, AMapToData m, int i) { } @Override - protected int addValRecodeMap(Map map, int id, int i) { + protected int addValRecodeMap(HashMapToInt map, int id, int i) { if(_n.get(i)) id = _a.addValRecodeMap(map, id, i); return id; diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java index 601fc78b2fc..1fc582924e4 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java @@ -23,8 +23,6 @@ import java.io.DataOutput; import java.io.IOException; import java.util.Arrays; -import java.util.HashMap; -import java.util.Map; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; @@ -669,9 +667,9 @@ public final boolean isNotEmpty(int i) { } @Override - protected Map createRecodeMap(int estimate, ExecutorService pool, int 
k) throws InterruptedException, ExecutionException { + protected HashMapToInt createRecodeMap(int estimate, ExecutorService pool, int k) throws InterruptedException, ExecutionException { try { - Map map = new HashMap<>((int) Math.min((long) estimate * 2, size())); + HashMapToInt map = new HashMapToInt((int) Math.min((long) estimate * 2, size())); for(int i = 0; i < size(); i++) { Object val = get(i); if(val != null) { diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java index e784086427d..54f6bd4003a 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java @@ -59,7 +59,7 @@ public ColumnEncoderRecode() { this(-1); } - protected ColumnEncoderRecode(int colID, HashMap rcdMap) { + protected ColumnEncoderRecode(int colID, Map rcdMap) { super(colID); _rcdMap = rcdMap; } @@ -304,16 +304,27 @@ public FrameBlock getMetaData(FrameBlock meta) { // create compact meta data representation StringBuilder sb = new StringBuilder(); // for reuse - int rowID = 0; - for(Entry e : _rcdMap.entrySet()) { - meta.set(rowID++, _colID - 1, // 1-based - constructRecodeMapEntry(e.getKey(), e.getValue(), sb)); - } - meta.getColumnMetadata(_colID - 1).setNumDistinct(getNumDistinctValues()); + final Inc rowID = new Inc(); + + final int colIDCorrected = _colID - 1; + _rcdMap.forEach( (k,v) -> { + meta.set(rowID.i(), colIDCorrected, // 1-based + constructRecodeMapEntry(k, v, sb)); + }); + // for(Entry e : _rcdMap.entrySet()) { + // } + meta.getColumnMetadata(colIDCorrected).setNumDistinct(getNumDistinctValues()); return meta; } + private static class Inc{ + int i = 0; + public int i(){ + return i++; + } + } + /** * Construct the recodemaps from the given input frame for all columns registered for recode. 
* diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java index 5ade585dbb6..893f06ea45e 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java @@ -52,6 +52,7 @@ import org.apache.sysds.runtime.frame.data.columns.ACompressedArray; import org.apache.sysds.runtime.frame.data.columns.Array; import org.apache.sysds.runtime.frame.data.columns.DDCArray; +import org.apache.sysds.runtime.frame.data.columns.HashMapToInt; import org.apache.sysds.runtime.frame.data.compress.ArrayCompressionStatistics; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.transform.encode.ColumnEncoderBin.BinMethod; @@ -184,16 +185,18 @@ private AColGroup recodeToDummy(ColumnEncoderComposite c) throws Exception { Array a = (Array) in.getColumn(colId - 1); boolean containsNull = a.containsNull(); estimateRCDMapSize(c); - Map map = a.getRecodeMap(c._estNumDistincts, pool, k / in.getNumColumns() ); + HashMapToInt map = (HashMapToInt) a.getRecodeMap(c._estNumDistincts, pool, k / in.getNumColumns()); List r = c.getEncoders(); - r.set(0, new ColumnEncoderRecode(colId, (HashMap) map)); + r.set(0, new ColumnEncoderRecode(colId, (HashMapToInt) map)); int domain = map.size(); if(containsNull && domain == 0) return new ColGroupEmpty(ColIndexFactory.create(1)); IColIndex colIndexes = ColIndexFactory.create(0, domain); - if(domain == 1 && !containsNull) + if(domain == 1 && !containsNull){ + nnz.addAndGet(in.getNumRows()); return ColGroupConst.create(colIndexes, new double[] {1}); + } ADictionary d = new IdentityDictionary(colIndexes.size(), containsNull); AMapToData m = createMappingAMapToData(a, map, containsNull); @@ -338,14 +341,32 @@ private AColGroup recode(ColumnEncoderComposite c) throws Exception { int colId = c._colID; Array a = 
(Array) in.getColumn(colId - 1); estimateRCDMapSize(c); - Map map = a.getRecodeMap(c._estNumDistincts, pool, k / in.getNumColumns() ); + HashMapToInt map = (HashMapToInt) a.getRecodeMap(c._estNumDistincts, pool, k / in.getNumColumns()); boolean containsNull = a.containsNull(); int domain = map.size(); // int domain = c.getDomainSize(); IColIndex colIndexes = ColIndexFactory.create(1); - if(domain == 1) + if(domain == 0 && containsNull){ + return new ColGroupEmpty(colIndexes); + } + if(domain == 1 && !containsNull){ + nnz.addAndGet(in.getNumRows()); return ColGroupConst.create(colIndexes, new double[] {1}); + } + ADictionary d = createRecodeDictionary(containsNull, domain); + + AMapToData m = createMappingAMapToData(a, map, containsNull); + + List r = c.getEncoders(); + r.set(0, new ColumnEncoderRecode(colId, (HashMapToInt) map)); + AColGroup ret = ColGroupDDC.create(colIndexes, d, m, null); + + nnz.addAndGet(ret.getNumberNonZeros(in.getNumRows())); + return ret; + } + + private ADictionary createRecodeDictionary(boolean containsNull, int domain) { MatrixBlock incrementing = new MatrixBlock(domain + (containsNull ? 
1 : 0), 1, false); for(int i = 0; i < domain; i++) incrementing.set(i, 0, i + 1); @@ -353,14 +374,7 @@ private AColGroup recode(ColumnEncoderComposite c) throws Exception { incrementing.set(domain, 0, Double.NaN); ADictionary d = MatrixBlockDictionary.create(incrementing); - - AMapToData m = createMappingAMapToData(a, map, containsNull); - - List r = c.getEncoders(); - r.set(0, new ColumnEncoderRecode(colId, (HashMap) map)); - AColGroup ret = ColGroupDDC.create(colIndexes, d, m, null); - nnz.addAndGet(ret.getNumberNonZeros(in.getNumRows())); - return ret; + return d; } @SuppressWarnings("unchecked") @@ -395,7 +409,7 @@ private AColGroup passThroughNormal(ColumnEncoderComposite c, final IColInde else { boolean containsNull = a.containsNull(); estimateRCDMapSize(c); - Map map = a.getRecodeMap(c._estNumDistincts, pool, k / in.getNumColumns() ); + HashMapToInt map = (HashMapToInt) a.getRecodeMap(c._estNumDistincts, pool, k / in.getNumColumns()); double[] vals = new double[map.size() + (containsNull ? 
1 : 0)]; if(containsNull) vals[map.size()] = Double.NaN; @@ -428,7 +442,7 @@ private AColGroup passThroughCompressed(final IColIndex colIndexes, final Ar return ret; } - private AMapToData createMappingAMapToData(Array a, Map map, boolean containsNull) + private AMapToData createMappingAMapToData(Array a, HashMapToInt map, boolean containsNull) throws Exception { final int si = map.size(); final int nRow = in.getNumRows(); @@ -443,7 +457,7 @@ private AMapToData createMappingAMapToData(Array a, Map map, return createMappingSingleThread(a, map, containsNull, si, nRow, m); } - private AMapToData CreateMappingParallel(Array a, Map map, boolean containsNull, final int si, + private AMapToData CreateMappingParallel(Array a, HashMapToInt map, boolean containsNull, final int si, final int nRow, final AMapToData m) throws InterruptedException, ExecutionException { final int tk = k / in.getNumColumns(); final int blkz = Math.max(ROW_PARALLELIZATION_THRESHOLD / 2, (nRow + tk) / tk); @@ -469,7 +483,7 @@ private AMapToData CreateMappingParallel(Array a, Map map, bo } - private AMapToData createMappingSingleThread(Array a, Map map, boolean containsNull, final int si, + private AMapToData createMappingSingleThread(Array a, HashMapToInt map, boolean containsNull, final int si, final int nRow, final AMapToData m) { if(containsNull) return createMappingAMapToDataWithNull(a, map, si, m, 0, nRow); @@ -477,14 +491,14 @@ private AMapToData createMappingSingleThread(Array a, Map map return createMappingAMapToDataNoNull(a, map, m, 0, nRow); } - private static AMapToData createMappingAMapToDataNoNull(Array a, Map map, AMapToData m, int start, + private static AMapToData createMappingAMapToDataNoNull(Array a, HashMapToInt map, AMapToData m, int start, int end) { for(int i = start; i < end; i++) a.setM(map, m, i); return m; } - private static AMapToData createMappingAMapToDataWithNull(Array a, Map map, int si, AMapToData m, + private static AMapToData createMappingAMapToDataWithNull(Array a, 
HashMapToInt map, int si, AMapToData m, int start, int end) { for(int i = start; i < end; i++) a.setM(map, si, m, i); @@ -518,8 +532,13 @@ private AColGroup hash(ColumnEncoderComposite c) { int domain = (int) CEHash.getK(); boolean nulls = a.containsNull(); IColIndex colIndexes = ColIndexFactory.create(0, 1); - if(domain == 1 && !nulls) + if(domain == 0 && nulls){ + return new ColGroupEmpty(colIndexes); + } + if(domain == 1 && !nulls){ + nnz.addAndGet(in.getNumRows()); return ColGroupConst.create(colIndexes, new double[] {1}); + } MatrixBlock incrementing = new MatrixBlock(domain + (nulls ? 1 : 0), 1, false); for(int i = 0; i < domain; i++) @@ -542,8 +561,13 @@ private AColGroup hashToDummy(ColumnEncoderComposite c) { int domain = (int) CEHash.getK(); boolean nulls = a.containsNull(); IColIndex colIndexes = ColIndexFactory.create(0, domain); - if(domain == 1 && !nulls) + if(domain == 0 && nulls){ + return new ColGroupEmpty(ColIndexFactory.create(1)); + } + if(domain == 1 && !nulls){ + nnz.addAndGet(in.getNumRows()); return ColGroupConst.create(colIndexes, new double[] {1}); + } ADictionary d = new IdentityDictionary(colIndexes.size(), nulls); AMapToData m = createHashMappingAMapToData(a, domain, nulls); AColGroup ret = ColGroupDDC.create(colIndexes, d, m, null); diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java index 7cafa0e437a..5d52671804a 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java @@ -161,7 +161,9 @@ public MatrixBlock encode(CacheBlock in, int k, boolean compressedOut){ } } catch(Exception ex) { - throw new DMLRuntimeException("Failed transform-encode frame with encoder:\n" + this, ex); + String st = this.toString(); + st = st.substring(0, Math.min(st.length(), 1000)); + throw new 
DMLRuntimeException("Failed transform-encode frame with encoder:\n" + st, ex); } } diff --git a/src/test/java/org/apache/sysds/performance/frame/Transform.java b/src/test/java/org/apache/sysds/performance/frame/Transform.java index c6159a1a329..806e40f5919 100644 --- a/src/test/java/org/apache/sysds/performance/frame/Transform.java +++ b/src/test/java/org/apache/sysds/performance/frame/Transform.java @@ -27,7 +27,6 @@ import org.apache.sysds.performance.generators.IGenerate; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.frame.data.columns.Array; -import org.apache.sysds.runtime.frame.data.lib.FrameLibCompress; import org.apache.sysds.runtime.transform.encode.EncoderFactory; import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder; import org.apache.sysds.test.TestUtils; @@ -49,7 +48,7 @@ public Transform(int N, IGenerate gen, int k, String spec) { } public void run() throws Exception { - // execute(() -> te(), () -> clear(), "Normal"); + execute(() -> te(), () -> clear(), "Normal"); execute(() -> tec(), () -> clear(), "Compressed"); } From f16abe9fdf3917560b88da6ec04c3fa4dfda15ce Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 13 Jan 2025 18:34:19 +0100 Subject: [PATCH 13/81] fix merge Recode --- .../org/apache/sysds/runtime/frame/data/columns/Array.java | 7 ++++--- src/test/java/org/apache/sysds/performance/README.md | 7 +++++++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java index d34b2ad6153..bfc20de3649 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java @@ -205,10 +205,11 @@ private HashMapToInt parallelCreateRecodeMap(int estimate, ExecutorService po * @param target The target object to merge the two maps into * @param from The Map to 
take entries from. */ - protected static void mergeRecodeMaps(Map target, Map from) { + protected static void mergeRecodeMaps(HashMapToInt target, HashMapToInt from) { final List fromEntriesOrdered = new ArrayList<>(Collections.nCopies(from.size(), null)); - for(Map.Entry e : from.entrySet()) - fromEntriesOrdered.set(e.getValue() - 1, e.getKey()); + from.forEach((k,v) -> { + fromEntriesOrdered.set(v - 1, k); + }); int id = target.size(); for(T e : fromEntriesOrdered) { if(target.putIfAbsent(e, id) == null) diff --git a/src/test/java/org/apache/sysds/performance/README.md b/src/test/java/org/apache/sysds/performance/README.md index 899931f5131..20d2757c805 100644 --- a/src/test/java/org/apache/sysds/performance/README.md +++ b/src/test/java/org/apache/sysds/performance/README.md @@ -82,3 +82,10 @@ Binary Operations ```bash java -jar -agentpath:$HOME/Programs/profiler/lib/libasyncProfiler.so=start,event=cpu,file=temp/log.html -XX:+UseNUMA target/systemds-3.3.0-SNAPSHOT-perf.jar 1006 500 ``` + + +transform encode + +```bash +java -jar -agentpath:$HOME/Programs/profiler/lib/libasyncProfiler.so=start,event=cpu,file=temp/log.html -XX:+UseNUMA target/systemds-3.3.0-SNAPSHOT-perf.jar 1007 +``` From fb325a0479b252d394f26b3a886b404a644c0811 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 13 Jan 2025 19:07:22 +0100 Subject: [PATCH 14/81] only compressed --- src/test/java/org/apache/sysds/performance/frame/Transform.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/test/java/org/apache/sysds/performance/frame/Transform.java b/src/test/java/org/apache/sysds/performance/frame/Transform.java index 806e40f5919..27d4acad255 100644 --- a/src/test/java/org/apache/sysds/performance/frame/Transform.java +++ b/src/test/java/org/apache/sysds/performance/frame/Transform.java @@ -48,7 +48,7 @@ public Transform(int N, IGenerate gen, int k, String spec) { } public void run() throws Exception { - execute(() -> te(), () -> clear(), "Normal"); + // 
execute(() -> te(), () -> clear(), "Normal"); execute(() -> tec(), () -> clear(), "Compressed"); } From 01a06a581f085da2b6bc49c3f05941e9b34fcfdf Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 13 Jan 2025 19:15:30 +0100 Subject: [PATCH 15/81] fix recodeMapTest --- .../sysds/test/component/frame/array/RecodeMapTest.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/test/java/org/apache/sysds/test/component/frame/array/RecodeMapTest.java b/src/test/java/org/apache/sysds/test/component/frame/array/RecodeMapTest.java index 1d8f8fdd7c2..4e0d0440d56 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/array/RecodeMapTest.java +++ b/src/test/java/org/apache/sysds/test/component/frame/array/RecodeMapTest.java @@ -91,9 +91,9 @@ public void createRecodeMapParallel2() throws Exception { assertTrue(rcm.size() == rcm2.size()); - for(String k : rcm.keySet()){ + rcm.forEach((k,v) ->{ assertEquals(rcm.get(k), rcm2.get(k)); - } + }); final List log = LoggingUtils.reinsert(appender); assertTrue(log.size() >= 1); } From 48135da25d43611f582219d5062cab3301dd8eb8 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 13 Jan 2025 19:21:23 +0100 Subject: [PATCH 16/81] another minor fix --- .../transform/encode/ColumnEncoderRecode.java | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java index 54f6bd4003a..e848927184a 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java @@ -28,7 +28,6 @@ import java.util.HashMap; import java.util.HashSet; import java.util.Map; -import java.util.Map.Entry; import java.util.Objects; import java.util.concurrent.Callable; @@ -342,10 +341,15 @@ public void writeExternal(ObjectOutput 
out) throws IOException { super.writeExternal(out); out.writeInt(_rcdMap.size()); - for(Entry e : _rcdMap.entrySet()) { - out.writeUTF(e.getKey().toString()); - out.writeInt(e.getValue()); - } + _rcdMap.forEach((k, v)-> { + try{ + out.writeUTF(k.toString()); + out.writeInt(v); + } + catch(Exception e){ + throw new RuntimeException(e); + } + }); } @Override From 487585a936c9ce5d397f24a01a5955c48d4a616c Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 13 Jan 2025 19:45:14 +0100 Subject: [PATCH 17/81] unsafe --- .../apache/sysds/runtime/frame/data/columns/BitSetArray.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/BitSetArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/BitSetArray.java index 4ddc66e8de3..4e79566f2ee 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/BitSetArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/BitSetArray.java @@ -107,7 +107,7 @@ public synchronized void set(int index, boolean value) { @Override public void setNullsFromString(int rl, int ru, Array value) { - final boolean unsafe = ru % 64 != 0 || rl % 64 != 0; + final boolean unsafe = ru % 64 != 63 || rl % 64 != 0; // ensure that it is safe to modify the values in the ranges. 
if(unsafe) { From 4f32ca9cc0ef6c5fab58076ef175a60bdca87f0a Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 13 Jan 2025 20:03:00 +0100 Subject: [PATCH 18/81] unsafe --- .../runtime/compress/colgroup/mapping/MapToCharPByte.java | 3 +-- .../apache/sysds/runtime/frame/data/columns/BitSetArray.java | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToCharPByte.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToCharPByte.java index 8c6fb6366ea..a5ce1e79285 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToCharPByte.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToCharPByte.java @@ -142,8 +142,7 @@ public void write(DataOutput out) throws IOException { out.writeInt(_data_c.length); for(int i = 0; i < _data_c.length; i++) out.writeChar(_data_c[i]); - for(int i = 0; i < _data_c.length; i++) - out.writeByte(_data_b[i]); + out.write(_data_b); } protected static MapToCharPByte readFields(DataInput in) throws IOException { diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/BitSetArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/BitSetArray.java index 4e79566f2ee..4ddc66e8de3 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/BitSetArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/BitSetArray.java @@ -107,7 +107,7 @@ public synchronized void set(int index, boolean value) { @Override public void setNullsFromString(int rl, int ru, Array value) { - final boolean unsafe = ru % 64 != 63 || rl % 64 != 0; + final boolean unsafe = ru % 64 != 0 || rl % 64 != 0; // ensure that it is safe to modify the values in the ranges. 
if(unsafe) { From b859d4b4087b114b46517b7d19c729363a79855e Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 13 Jan 2025 20:05:19 +0100 Subject: [PATCH 19/81] unsafe and safe --- .../frame/data/columns/BitSetArray.java | 36 +++++++++---------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/BitSetArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/BitSetArray.java index 4ddc66e8de3..094d07c0dd4 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/BitSetArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/BitSetArray.java @@ -107,26 +107,26 @@ public synchronized void set(int index, boolean value) { @Override public void setNullsFromString(int rl, int ru, Array value) { - final boolean unsafe = ru % 64 != 0 || rl % 64 != 0; + // final boolean unsafe = ru % 64 != 0 || rl % 64 != 0; // ensure that it is safe to modify the values in the ranges. - if(unsafe) { - // find rl rounded up to start safe - final int rl64 = Math.min((rl / 64 + 1) * 64, ru); - final int ru64 = (ru / 64) * 64; - - for(int i = rl; i < rl64; i++) - unsafeSet(i, value.get(i) != null); - for(int i = rl64; i < ru64; i++) - set(i, value.get(i) != null); - for(int i = ru64; i < ru; i++) - unsafeSet(i, value.get(i) != null); - } - else { - // safe all the way - for(int i = rl; i < ru; i++) - set(i, value.get(i) != null); - } + // if(unsafe) { + // find rl rounded up to start safe + final int rl64 = Math.min((rl / 64 + 1) * 64, ru); + final int ru64 = (ru / 64) * 64; + + for(int i = rl; i < rl64; i++) + set(i, value.get(i) != null); + for(int i = rl64; i < ru64; i++) + unsafeSet(i, value.get(i) != null); + for(int i = ru64; i < ru; i++) + set(i, value.get(i) != null); + // } + // else { + // // safe all the way + // for(int i = rl; i < ru; i++) + // set(i, value.get(i) != null); + // } } private void unsafeSet(int index, boolean value) { From 
058187f97786e564b9b80eea3a89e4bc85956d20 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 13 Jan 2025 20:12:58 +0100 Subject: [PATCH 20/81] chars buff --- .../compress/colgroup/mapping/MapToChar.java | 38 ++++++++++++------- .../colgroup/mapping/MapToCharPByte.java | 3 +- 2 files changed, 25 insertions(+), 16 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToChar.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToChar.java index 5982157726d..cb08474ee2f 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToChar.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToChar.java @@ -135,27 +135,37 @@ public void write(DataOutput out) throws IOException { out.writeByte(MAP_TYPE.CHAR.ordinal()); out.writeInt(getUnique()); out.writeInt(_data.length); + writeChars(out, _data); + + } + + + protected static void writeChars(DataOutput out, char[] _data_c) throws IOException { final int BS = 100; - if(_data.length > BS) { + if(_data_c.length > BS) { final byte[] buff = new byte[BS * 2]; - for(int i = 0; i < _data.length;) { - if(i + BS <= _data.length) { - for(int o = 0; o < BS; o++) { - IOUtilFunctions.shortToBa(_data[i++], buff, o * 2); - } - out.write(buff); - } - else {// remaining. - for(; i < _data.length; i++) - out.writeChar(_data[i]); - } + for(int i = 0; i < _data_c.length;) { + i = writeCharsBlock(out,_data_c, BS, buff, i); } } else { - for(int i = 0; i < _data.length; i++) - out.writeChar(_data[i]); + for(int i = 0; i < _data_c.length; i++) + out.writeChar(_data_c[i]); } + } + private static int writeCharsBlock(DataOutput out, char[] _data_c, final int BS, final byte[] buff, int i) throws IOException { + if(i + BS <= _data_c.length) { + for(int o = 0; o < BS; o++) { + IOUtilFunctions.shortToBa(_data_c[i++], buff, o * 2); + } + out.write(buff); + } + else {// remaining. 
+ for(; i < _data_c.length; i++) + out.writeChar(_data_c[i]); + } + return i; } protected static MapToChar readFields(DataInput in) throws IOException { diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToCharPByte.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToCharPByte.java index a5ce1e79285..9c679967d92 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToCharPByte.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToCharPByte.java @@ -140,8 +140,7 @@ public void write(DataOutput out) throws IOException { out.writeByte(MAP_TYPE.CHAR_BYTE.ordinal()); out.writeInt(getUnique()); out.writeInt(_data_c.length); - for(int i = 0; i < _data_c.length; i++) - out.writeChar(_data_c[i]); + MapToChar.writeChars(out, _data_c); out.write(_data_b); } From 67ad0d8a7fc3d84eb46e5f95c327bc8f68aeaf91 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 13 Jan 2025 20:52:53 +0100 Subject: [PATCH 21/81] map to char refine --- .../runtime/compress/colgroup/mapping/MapToChar.java | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToChar.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToChar.java index cb08474ee2f..df8917a9393 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToChar.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToChar.java @@ -139,13 +139,12 @@ public void write(DataOutput out) throws IOException { } - protected static void writeChars(DataOutput out, char[] _data_c) throws IOException { final int BS = 100; if(_data_c.length > BS) { final byte[] buff = new byte[BS * 2]; for(int i = 0; i < _data_c.length;) { - i = writeCharsBlock(out,_data_c, BS, buff, i); + i = writeCharsBlock(out, _data_c, BS, buff, i); } } else { @@ -154,7 +153,8 @@ protected static void 
writeChars(DataOutput out, char[] _data_c) throws IOExcept } } - private static int writeCharsBlock(DataOutput out, char[] _data_c, final int BS, final byte[] buff, int i) throws IOException { + private static int writeCharsBlock(DataOutput out, char[] _data_c, final int BS, final byte[] buff, int i) + throws IOException { if(i + BS <= _data_c.length) { for(int o = 0; o < BS; o++) { IOUtilFunctions.shortToBa(_data_c[i++], buff, o * 2); @@ -260,7 +260,6 @@ private void getCountsBy8P(int[] ret, int s, int e) { } } - @Override public AMapToData resize(int unique) { final int size = _data.length; @@ -336,7 +335,6 @@ public AMapToData appendN(IMapToDataGroup[] d) { return new MapToChar(getUnique(), ret); } - @Override public boolean equals(AMapToData e) { return e instanceof MapToChar && // From 3683378951425b9dfc008ca72c434e6d1156a9f3 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 13 Jan 2025 20:54:40 +0100 Subject: [PATCH 22/81] restore log output --- .../runtime/transform/encode/MultiColumnEncoder.java | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java index 5d52671804a..1ee068f8089 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java @@ -157,6 +157,14 @@ public MatrixBlock encode(CacheBlock in, int k, boolean compressedOut){ MatrixBlock out = apply(in, k); t1 = System.nanoTime(); LOG.debug("Elapsed time for apply phase: "+ ((double) t1 - t0) / 1000000 + " ms"); + + if(LOG.isDebugEnabled()) { + LOG.debug("Transform Encode output mem size: " + out.getInMemorySize()); + LOG.debug(String.format("Transform Encode output rows : %10d", out.getNumRows())); + LOG.debug(String.format("Transform Encode output cols : %10d", out.getNumColumns())); + 
LOG.debug(String.format("Transform Encode output sparsity : %10.5f", out.getSparsity())); + LOG.debug(String.format("Transform Encode output nnz : %10d", out.getNonZeros())); + } return out; } } From b6933d4a3671be579e0b07732432b7556fc33e8c Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 13 Jan 2025 21:07:16 +0100 Subject: [PATCH 23/81] otherBranch Logging --- .../transform/encode/MultiColumnEncoder.java | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java index 1ee068f8089..10efd9b9f43 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java @@ -136,6 +136,7 @@ public MatrixBlock encode(CacheBlock in, int k, boolean compressedOut){ pool.shutdown(); } outputMatrixPostProcessing(out, k); + outputLogging(out); return out; } else { @@ -158,13 +159,7 @@ public MatrixBlock encode(CacheBlock in, int k, boolean compressedOut){ t1 = System.nanoTime(); LOG.debug("Elapsed time for apply phase: "+ ((double) t1 - t0) / 1000000 + " ms"); - if(LOG.isDebugEnabled()) { - LOG.debug("Transform Encode output mem size: " + out.getInMemorySize()); - LOG.debug(String.format("Transform Encode output rows : %10d", out.getNumRows())); - LOG.debug(String.format("Transform Encode output cols : %10d", out.getNumColumns())); - LOG.debug(String.format("Transform Encode output sparsity : %10.5f", out.getSparsity())); - LOG.debug(String.format("Transform Encode output nnz : %10d", out.getNonZeros())); - } + outputLogging(out); return out; } } @@ -175,6 +170,16 @@ public MatrixBlock encode(CacheBlock in, int k, boolean compressedOut){ } } + private void outputLogging(MatrixBlock out) { + if(LOG.isDebugEnabled()) { + LOG.debug("Transform Encode output mem size: " + 
out.getInMemorySize()); + LOG.debug(String.format("Transform Encode output rows : %10d", out.getNumRows())); + LOG.debug(String.format("Transform Encode output cols : %10d", out.getNumColumns())); + LOG.debug(String.format("Transform Encode output sparsity : %10.5f", out.getSparsity())); + LOG.debug(String.format("Transform Encode output nnz : %10d", out.getNonZeros())); + } + } + protected List getEncoders() { return _columnEncoders; } From 0f45fe052ddcd5f4cc7f20b97432125583f3dad2 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 13 Jan 2025 21:33:18 +0100 Subject: [PATCH 24/81] not really a change --- .../java/org/apache/sysds/runtime/frame/data/FrameBlock.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java index 7566ba2fd55..f8c90fcdeec 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/FrameBlock.java @@ -769,7 +769,7 @@ public void write(DataOutput out) throws IOException { out.writeUTF(getColumnName(j)); _colmeta[j].write(out); } - if(type >= 0 && nRow > 0) // if allocated write column data + if(type > 0 && nRow > 0) // if allocated write column data _coldata[j].write(out); } } @@ -910,7 +910,7 @@ public long getExactSerializedSize() { size += IOUtilFunctions.getUTFSize(getColumnName(j)); size += _colmeta[j].getExactSerializedSize(); } - if(type >= 0) + if(type > 0) size += _coldata[j].getExactSerializedSize(); } return size; From 6744bf8b45902db9db60859190a6299dab0274f9 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 13 Jan 2025 21:35:46 +0100 Subject: [PATCH 25/81] Binary readers update --- .../runtime/io/FrameReaderBinaryBlock.java | 31 +++++++++ .../runtime/io/FrameWriterBinaryBlock.java | 67 ++++++++++++++++-- .../sysds/runtime/io/IOUtilFunctions.java | 69 +++++++++++++++---- 3 files 
changed, 151 insertions(+), 16 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/io/FrameReaderBinaryBlock.java b/src/main/java/org/apache/sysds/runtime/io/FrameReaderBinaryBlock.java index 7625737d9f8..f2244aeb7e1 100644 --- a/src/main/java/org/apache/sysds/runtime/io/FrameReaderBinaryBlock.java +++ b/src/main/java/org/apache/sysds/runtime/io/FrameReaderBinaryBlock.java @@ -31,6 +31,8 @@ import org.apache.sysds.conf.ConfigurationManager; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.columns.ArrayWrapper; +import org.apache.sysds.runtime.frame.data.columns.DDCArray; /** * Single-threaded frame binary block reader. @@ -58,6 +60,9 @@ public final FrameBlock readFrameFromHDFS(String fname, ValueType[] schema, Stri // core read (sequential/parallel) readBinaryBlockFrameFromHDFS(path, job, fs, ret, rlen, clen); + + readBinaryDictionariesFromHDFS(new Path(fname + ".dict"), job, fs, ret); + return ret; } @@ -114,6 +119,29 @@ protected static void readBinaryBlockFrameFromSequenceFile(Path path, JobConf jo } } + protected static void readBinaryDictionariesFromHDFS(Path path, JobConf job, FileSystem fs, FrameBlock ret) { + try{ + if(fs.exists(path)){ + LongWritable key = new LongWritable(); + ArrayWrapper value = new ArrayWrapper(null); + SequenceFile.Reader reader = new SequenceFile.Reader(job, SequenceFile.Reader.file(path)); + try{ + while(reader.next(key,value)){ + int colId = (int)key.get(); + DDCArray a = (DDCArray) ret.getColumn(colId); + ret.setColumn(colId, a.setDict(value._a)); + } + } + finally{ + IOUtilFunctions.closeSilently(reader); + } + } + } + catch(IOException e){ + throw new DMLRuntimeException("Failed to read Frame Dictionaries", e); + } + } + /** * Specific functionality of FrameReaderBinaryBlock, mostly used for testing. 
* @@ -143,4 +171,7 @@ public FrameBlock readFirstBlock(String fname) throws IOException { return value; } + + + } diff --git a/src/main/java/org/apache/sysds/runtime/io/FrameWriterBinaryBlock.java b/src/main/java/org/apache/sysds/runtime/io/FrameWriterBinaryBlock.java index 859cbe028c2..b72661ba3ba 100644 --- a/src/main/java/org/apache/sysds/runtime/io/FrameWriterBinaryBlock.java +++ b/src/main/java/org/apache/sysds/runtime/io/FrameWriterBinaryBlock.java @@ -20,6 +20,8 @@ package org.apache.sysds.runtime.io; import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; @@ -29,6 +31,10 @@ import org.apache.sysds.conf.ConfigurationManager; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.columns.Array; +import org.apache.sysds.runtime.frame.data.columns.ArrayWrapper; +import org.apache.sysds.runtime.frame.data.columns.DDCArray; +import org.apache.sysds.runtime.matrix.data.Pair; import org.apache.sysds.runtime.util.HDFSTool; /** @@ -43,30 +49,67 @@ public final void writeFrameToHDFS(FrameBlock src, String fname, long rlen, long // prepare file access JobConf job = new JobConf(ConfigurationManager.getCachedJobConf()); Path path = new Path(fname); - + // if the file already exists on HDFS, remove it. 
HDFSTool.deleteFileIfExistOnHDFS(fname); - + HDFSTool.deleteFileIfExistOnHDFS(fname + ".dict"); + // bound check for src block if(src.getNumRows() > rlen || src.getNumColumns() > clen) { throw new IOException("Frame block [1:" + src.getNumRows() + ",1:" + src.getNumColumns() + "] " + "out of overall frame range [1:" + rlen + ",1:" + clen + "]."); } + Pair>>, FrameBlock> prep = extractDictionaries(src); + src = prep.getValue(); + // write binary block to hdfs (sequential/parallel) - writeBinaryBlockFrameToHDFS(path, job, src, rlen, clen); + writeBinaryBlockFrameToHDFS(path, job, prep.getValue(), rlen, clen); + + if(prep.getKey().size() > 0) + writeBinaryBlockDictsToSequenceFile(new Path(fname + ".dict"), job, prep.getKey()); + + } + + protected Pair>>, FrameBlock> extractDictionaries(FrameBlock src){ + List>> dicts = new ArrayList<>(); + int blen = ConfigurationManager.getBlocksize(); + if(src.getNumRows() < blen ) + return new Pair<>(dicts, src); + boolean modified = false; + for(int i = 0; i < src.getNumColumns(); i++){ + Array a = src.getColumn(i); + if(a instanceof DDCArray){ + DDCArray d = (DDCArray)a; + dicts.add(new Pair<>(i, d.getDict())); + if(modified == false){ + modified = true; + // make sure other users of this frame does not get effected + src = src.copyShallow(); + } + src.setColumn(i, d.nullDict()); + } + } + return new Pair<>(dicts, src); } protected void writeBinaryBlockFrameToHDFS(Path path, JobConf job, FrameBlock src, long rlen, long clen) throws IOException, DMLRuntimeException { FileSystem fs = IOUtilFunctions.getFileSystem(path); int blen = ConfigurationManager.getBlocksize(); - + // sequential write to single file writeBinaryBlockFrameToSequenceFile(path, job, fs, src, blen, 0, (int) rlen); IOUtilFunctions.deleteCrcFilesFromLocalFileSystem(fs, path); } + protected void writeBinaryBlockDictsToSequenceFile(Path path, JobConf job, List>> dicts) + throws IOException, DMLRuntimeException { + FileSystem fs = IOUtilFunctions.getFileSystem(path); + 
writeBinaryBlockDictsToSequenceFile(path, job, fs, dicts); + IOUtilFunctions.deleteCrcFilesFromLocalFileSystem(fs, path); + } + /** * Internal primitive to write a block-aligned row range of a frame to a single sequence file, which is used for both * single- and multi-threaded writers (for consistency). @@ -111,4 +154,20 @@ protected static void writeBinaryBlockFrameToSequenceFile(Path path, JobConf job IOUtilFunctions.closeSilently(writer); } } + + protected static void writeBinaryBlockDictsToSequenceFile(Path path, JobConf job, FileSystem fs, List>> dicts) throws IOException{ + final Writer writer = IOUtilFunctions.getSeqWriterArray(path, job, 1); + try{ + LongWritable index = new LongWritable(); + + for(int i = 0; i < dicts.size(); i++){ + Pair> p = dicts.get(i); + index.set(p.getKey()); + writer.append(index, new ArrayWrapper(p.getValue())); + } + } + finally { + IOUtilFunctions.closeSilently(writer); + } + } } diff --git a/src/main/java/org/apache/sysds/runtime/io/IOUtilFunctions.java b/src/main/java/org/apache/sysds/runtime/io/IOUtilFunctions.java index 9aa948acf40..e4feb3ed756 100644 --- a/src/main/java/org/apache/sysds/runtime/io/IOUtilFunctions.java +++ b/src/main/java/org/apache/sysds/runtime/io/IOUtilFunctions.java @@ -72,10 +72,10 @@ import org.apache.sysds.runtime.data.TensorBlock; import org.apache.sysds.runtime.data.TensorIndexes; import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.columns.ArrayWrapper; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.data.MatrixCell; import org.apache.sysds.runtime.matrix.data.MatrixIndexes; -import org.apache.sysds.runtime.transform.TfUtils; import org.apache.sysds.runtime.util.LocalFileUtils; import io.airlift.compress.lzo.LzoCodec; @@ -242,6 +242,29 @@ public static String[] splitCSV(String str, String delim){ return tokens.toArray(new String[0]); } + public static String[] splitCSV(String str, String delim, int clen){ + 
if(str == null || str.isEmpty()) + return new String[] {""}; + + int from = 0, to = 0; + final int len = str.length(); + final int delimLen = delim.length(); + + final String[] tokens = new String[clen]; + int c = 0; + while(from < len) { // for all tokens + to = getTo(str, from, delim, len, delimLen); + tokens[c++] = str.substring(from, to); + from = to + delimLen; + } + + // handle empty string at end + if(from == len) + tokens[c++] = ""; + + return tokens; + } + /** * Splits a string by a specified delimiter into all tokens, including empty * while respecting the rules for quotes and escapes defined in RFC4180, @@ -346,7 +369,7 @@ private static boolean isEmptyMatch(final String str, final int from, final Stri * @param dLen The length of the delimiter string * @return The next index. */ - private static int getTo(final String str, final int from, final String delim, + public static int getTo(final String str, final int from, final String delim, final int len, final int dLen) { final char cq = CSV_QUOTE_CHAR; final int fromP1 = from + 1; @@ -404,17 +427,32 @@ private static int getToNoQuoteCharDelim(final String str, final int from, final } public static String trim(String str) { + final int len = str.length(); + if(len == 0) + return str; + return trim(str, len); + } + + /** + * Caller must have a string of at least 1 character length. + * + * @param str string to trim + * @param len length of string + * @return the trimmed string. + */ + public static String trim(final String str, final int len) { try{ - final int len = str.length(); - if(len == 0) - return str; // short the call to return input if not whitespace in ends. 
- else if(str.charAt(0) <= ' ' || str.charAt(len -1) <= ' ') + if(str.charAt(0) <= ' ' || str.charAt(len -1) <= ' ') return str.trim(); else return str; - }catch(Exception e){ - throw new RuntimeException("failed trimming: " + str + " " + str.length(),e); + } + catch(NullPointerException e){ + return null; + } + catch(Exception e){ + throw new RuntimeException("failed trimming: " + str + " " + str.length(), e); } } @@ -657,10 +695,10 @@ public static int countNumColumnsCSV(InputSplit[] splits, InputFormat informat, try { if( reader.next(key, value) ) { boolean hasValue = true; - if( value.toString().startsWith(TfUtils.TXMTD_MVPREFIX) ) - hasValue = reader.next(key, value); - if( value.toString().startsWith(TfUtils.TXMTD_NDPREFIX) ) - hasValue = reader.next(key, value); + // if( value.toString().startsWith(TfUtils.TXMTD_MVPREFIX) ) + // hasValue = reader.next(key, value); + // if( value.toString().startsWith(TfUtils.TXMTD_NDPREFIX) ) + // hasValue = reader.next(key, value); String row = value.toString().trim(); if( hasValue && !row.isEmpty() ) { ncol = IOUtilFunctions.countTokensCSV(row, delim); @@ -901,6 +939,13 @@ public static Writer getSeqWriterFrame(Path path, Configuration job, int replica Writer.replication((short) (replication > 0 ? replication : 1))); } + public static Writer getSeqWriterArray(Path path, Configuration job, int replication) throws IOException { + return SequenceFile.createWriter(job, Writer.file(path), Writer.bufferSize(4096), + Writer.keyClass(LongWritable.class), Writer.valueClass(ArrayWrapper.class), + Writer.compression(getCompressionEncodingType(), getCompressionCodec()), + Writer.replication((short) (replication > 0 ? replication : 1))); + } + public static Writer getSeqWriterTensor(Path path, Configuration job, int replication) throws IOException { return SequenceFile.createWriter(job, Writer.file(path), Writer.bufferSize(4096), Writer.replication((short) (replication > 0 ? 
replication : 1)), From 4f75d4cb4863b4642271cf4aef3669321a407ee1 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 13 Jan 2025 21:44:47 +0100 Subject: [PATCH 26/81] compressed writer --- .../sysds/runtime/io/FrameWriterCompressed.java | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/io/FrameWriterCompressed.java b/src/main/java/org/apache/sysds/runtime/io/FrameWriterCompressed.java index 82c5a08e2c0..2e4c3d5ac3f 100644 --- a/src/main/java/org/apache/sysds/runtime/io/FrameWriterCompressed.java +++ b/src/main/java/org/apache/sysds/runtime/io/FrameWriterCompressed.java @@ -19,14 +19,13 @@ package org.apache.sysds.runtime.io; -import java.io.IOException; +import java.util.List; -import org.apache.hadoop.fs.Path; -import org.apache.hadoop.mapred.JobConf; import org.apache.sysds.hops.OptimizerUtils; -import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.columns.Array; import org.apache.sysds.runtime.frame.data.lib.FrameLibCompress; +import org.apache.sysds.runtime.matrix.data.Pair; public class FrameWriterCompressed extends FrameWriterBinaryBlockParallel { @@ -37,11 +36,10 @@ public FrameWriterCompressed(boolean parallel) { } @Override - protected void writeBinaryBlockFrameToHDFS(Path path, JobConf job, FrameBlock src, long rlen, long clen) - throws IOException, DMLRuntimeException { + protected Pair>>, FrameBlock> extractDictionaries(FrameBlock src) { int k = parallel ? 
OptimizerUtils.getParallelBinaryWriteParallelism() : 1; FrameBlock compressed = FrameLibCompress.compress(src, k); - super.writeBinaryBlockFrameToHDFS(path, job, compressed, rlen, clen); + return super.extractDictionaries(compressed); } } From 696c9c4891e47cf6226feffe7837b198a3f44c70 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 13 Jan 2025 22:16:15 +0100 Subject: [PATCH 27/81] we try this optimization --- .../estim/encoding/DenseEncoding.java | 40 ++++++------------- 1 file changed, 12 insertions(+), 28 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/DenseEncoding.java b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/DenseEncoding.java index 8fc9d96f728..aae28599398 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/DenseEncoding.java +++ b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/DenseEncoding.java @@ -247,7 +247,7 @@ else if(ret instanceof MapToCharPByte) private final void combineDenseWIthHashMapPByteOut(final AMapToData lm, final AMapToData rm, final int size, final int nVL, final MapToCharPByte ret, HashMapLongInt m) { for(int r = 0; r < size; r++) - addValHashMapCharByte(lm.getIndex(r) + rm.getIndex(r) * nVL, r, m, ret); + addValHashMapCharByte(calculateID(lm, rm, nVL, r), r, m, ret); } private final void combineDenseWIthHashMapCharOut(final AMapToData lm, final AMapToData rm, final int size, @@ -261,7 +261,7 @@ private final void combineDenseWIthHashMapCharOut(final AMapToData lm, final AMa private final void combineDenseWIthHashMapCharOutGeneric(final AMapToData lm, final AMapToData rm, final int size, final int nVL, final MapToChar ret, HashMapLongInt m) { for(int r = 0; r < size; r++) - addValHashMapChar(lm.getIndex(r) + rm.getIndex(r) * nVL, r, m, ret); + addValHashMapChar(calculateID(lm, rm, nVL, r), r, m, ret); } private final void combineDenseWIthHashMapAllChar(final AMapToData lm, final AMapToData rm, final int size, @@ -276,43 
+276,27 @@ private final void combineDenseWIthHashMapAllChar(final AMapToData lm, final AMa protected final void combineDenseWithHashMapGeneric(final AMapToData lm, final AMapToData rm, final int size, final int nVL, final AMapToData ret, HashMapLongInt m) { for(int r = 0; r < size; r++) - addValHashMap(lm.getIndex(r) + rm.getIndex(r) * nVL, r, m, ret); + addValHashMap(calculateID(lm, rm, nVL, r), r, m, ret); } protected final DenseEncoding combineDenseWithMapToData(final AMapToData lm, final AMapToData rm, final int size, final int nVL, final AMapToData ret, final int maxUnique, final AMapToData m) { - if(m instanceof MapToChar) - return combineDenseWithMapToDataToChar(lm, rm, size, nVL, ret, maxUnique, (MapToChar) m); - else - return combineDenseWithMapToDataGeneric(lm, rm, size, nVL, ret, maxUnique, m); - - } - - protected final DenseEncoding combineDenseWithMapToDataToChar(final AMapToData lm, final AMapToData rm, - final int size, final int nVL, final AMapToData ret, final int maxUnique, final MapToChar m) { int newUID = 1; - for(int r = 0; r < size; r++) - newUID = addValMapToDataChar(lm.getIndex(r) + rm.getIndex(r) * nVL, r, m, newUID, ret); + newUID = addValRange(lm, rm, size, nVL, ret, m, newUID, 0, size); ret.setUnique(newUID - 1); return new DenseEncoding(ret); + } - protected final DenseEncoding combineDenseWithMapToDataGeneric(final AMapToData lm, final AMapToData rm, - final int size, final int nVL, final AMapToData ret, final int maxUnique, final AMapToData m) { - int newUID = 1; - for(int r = 0; r < size; r++) - newUID = addValMapToData(lm.getIndex(r) + rm.getIndex(r) * nVL, r, m, newUID, ret); - ret.setUnique(newUID - 1); - return new DenseEncoding(ret); + private int addValRange(final AMapToData lm, final AMapToData rm, final int size, final int nVL, + final AMapToData ret, final AMapToData m, int newUID, int start, int end) { + for(int r = start; r < end; r++) + newUID = addValMapToData(calculateID(lm, rm, nVL, r), r, m, newUID, ret); + return 
newUID; } - protected static int addValMapToDataChar(final int nv, final int r, final MapToChar map, int newId, - final AMapToData d) { - int mv = map.getIndex(nv); - if(mv == 0) - mv = map.setAndGet(nv, newId++); - d.set(r, mv - 1); - return newId; + private int calculateID(final AMapToData lm, final AMapToData rm, final int nVL, int r) { + return lm.getIndex(r) + rm.getIndex(r) * nVL; } protected static int addValMapToData(final int nv, final int r, final AMapToData map, int newId, From 0cd1d8d8fe3cd25751ad2f598cc91fe2e0d76600 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 13 Jan 2025 22:31:13 +0100 Subject: [PATCH 28/81] gammaSquared --- .../runtime/compress/estim/sample/HassAndStokes.java | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/sample/HassAndStokes.java b/src/main/java/org/apache/sysds/runtime/compress/estim/sample/HassAndStokes.java index 06f191ca6e7..fc0b9e11557 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/estim/sample/HassAndStokes.java +++ b/src/main/java/org/apache/sysds/runtime/compress/estim/sample/HassAndStokes.java @@ -132,9 +132,11 @@ private static double getDuj2aEstimate(double q, int f[], int n, int dn, double private static double getGammaSquared(double D, int[] f, int n, int N) { // Computes the "squared coefficient of variation" based on a given initial estimate D (Eq 16). 
double gamma = 0; - for(int i = 1; i <= f.length; i++) - if(f[i - 1] != 0) - gamma += i * (i - 1) * f[i - 1]; + for(int i = 2; i <= f.length; i++){ + int im1 = i - 1; + // if(f[im1] != 0) + gamma += i * (im1) * f[im1]; + } gamma *= D / n / n; gamma += D / N - 1; return Math.max(0, gamma); From 1d93f2f1de3fa4bbe5698961110f3f8f6deef76c Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 13 Jan 2025 22:41:39 +0100 Subject: [PATCH 29/81] reduce precalculated --- .../org/apache/sysds/runtime/compress/cocode/CoCodeGreedy.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeGreedy.java b/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeGreedy.java index 51cc34fa9e4..9b5828c3bbf 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeGreedy.java +++ b/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeGreedy.java @@ -224,7 +224,8 @@ public Object call() throws Exception { final int maxCombined = c1i.getNumVals() * c2i.getNumVals(); if(maxCombined < 0 // int overflow - || maxCombined > c1i.getNumRows()) // higher combined than number of rows. + || maxCombined > c1i.getNumRows() // higher than number of rows + || maxCombined > 100000) // higher than 100k ... then lets not precalculate it. 
return null; final IColIndex c = _c1._indexes.combine(_c2._indexes); From ce985585f619f391b3fd1eb8e3bc496b395b2c59 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 13 Jan 2025 23:17:28 +0100 Subject: [PATCH 30/81] entry set for HasMapToInt --- .../frame/data/columns/HashMapToInt.java | 132 +++++++++++++----- 1 file changed, 100 insertions(+), 32 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java index bbc92ec7a8e..e63c3266126 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java @@ -20,7 +20,9 @@ package org.apache.sysds.runtime.frame.data.columns; import java.io.Serializable; +import java.util.AbstractSet; import java.util.Collection; +import java.util.Iterator; import java.util.Map; import java.util.Set; import java.util.function.BiConsumer; @@ -32,13 +34,12 @@ public class HashMapToInt implements Map, Serializable, Cloneable static final int MAXIMUM_CAPACITY = 1 << 30; static final float DEFAULT_LOAD_FACTOR = 0.75f; - - static class Node { + static class Node implements Entry { final K key; int value; Node next; - Node( K key, int value, Node next) { + Node(K key, int value, Node next) { this.key = key; this.value = value; this.next = next; @@ -47,6 +48,21 @@ static class Node { public final void setNext(Node n) { next = n; } + + @Override + public K getKey() { + return key; + } + + @Override + public Integer getValue() { + return value; + } + + @Override + public Integer setValue(Integer value) { + return this.value = value; + } } protected Node[] buckets; @@ -58,12 +74,9 @@ public HashMapToInt(int capacity) { alloc(Math.max(capacity, 16)); } - - - @SuppressWarnings({"unchecked"}) protected void alloc(int size) { - Node[] tmp = (Node[])new Node[size]; + Node[] tmp = (Node[]) new Node[size]; buckets = tmp; 
} @@ -102,15 +115,16 @@ public int getI(K key) { final int ix = hash(key); Node b = buckets[ix]; if(b != null) { - do{ + do { if(b.key.equals(key)) return b.value; - } while((b = b.next) != null); + } + while((b = b.next) != null); } return -1; } - public int hash(K key){ + public int hash(K key) { return Math.abs(key.hashCode()) % buckets.length; } @@ -123,8 +137,8 @@ public Integer put(K key, Integer value) { return null; } - @Override - public Integer putIfAbsent(K key, Integer value){ + @Override + public Integer putIfAbsent(K key, Integer value) { int i = putIfAbsentI(key, value); if(i != -1) return i; @@ -132,22 +146,21 @@ public Integer putIfAbsent(K key, Integer value){ return null; } - public int putIfAbsentI(K key, int value){ + public int putIfAbsentI(K key, int value) { final int ix = hash(key); Node b = buckets[ix]; - if( b == null) + if(b == null) return createBucket(ix, key, value); - else + else return putIfAbsentBucket(ix, key, value); } - private int putIfAbsentBucket(int ix, K key, int value) { Node b = buckets[ix]; - while(true){ + while(true) { if(b.key.equals(key)) return b.value; - if(b.next == null){ + if(b.next == null) { b.next = new Node<>(key, value, null); size++; return -1; @@ -166,21 +179,21 @@ public int putI(K key, int value) { } private int createBucket(int ix, K key, int value) { - buckets[ix] = new Node(key, value, null ); + buckets[ix] = new Node(key, value, null); size++; return -1; } private int addToBucket(int ix, K key, int value) { Node b = buckets[ix]; - while(true){ + while(true) { - if(b.key.equals(key)){ + if(b.key.equals(key)) { int tmp = b.value; b.value = value; return tmp; } - if(b.next == null){ + if(b.next == null) { b.next = new Node<>(key, value, null); size++; return -1; @@ -215,15 +228,15 @@ public Collection values() { } @Override - public Set> entrySet() { - throw new UnsupportedOperationException("Unimplemented method 'entrySet'"); + public Set> entrySet() { + return new EntrySet(); } - @Override + 
@Override public void forEach(BiConsumer action) { - for(Node n : buckets){ - if(n != null){ - do{ + for(Node n : buckets) { + if(n != null) { + do { action.accept(n.key, n.value); } while((n = n.next) != null); @@ -231,13 +244,68 @@ public void forEach(BiConsumer action) { } } - @Override - public String toString(){ + @Override + public String toString() { StringBuilder sb = new StringBuilder(); - this.forEach((k,v) -> { - sb.append("("+k +"→" + v+")"); + this.forEach((k, v) -> { + sb.append("(" + k + "→" + v + ")"); }); return sb.toString(); } + private final class EntrySet extends AbstractSet> { + + @Override + public int size() { + return size; + } + + @Override + public Iterator> iterator() { + return new EntryIterator(); + } + + } + + private final class EntryIterator implements Iterator> { + Node next; + int bucketId = 0; + + protected EntryIterator() { + for(; bucketId < buckets.length; bucketId++) { + if(buckets[bucketId] != null) { + next = buckets[bucketId]; + break; + } + } + } + + @Override + public boolean hasNext() { + return next != null; + } + + @Override + public Entry next() { + + Node e = next; + + if(e.next != null) + next = e.next; + else { + for(; ++bucketId < buckets.length; bucketId++) { + if(buckets[bucketId] != null) { + next = buckets[bucketId]; + break; + } + } + if(bucketId == buckets.length) + next = null; + } + + return e; + } + + } + } From 987de0956a51d3d5cbe92a91f718e40aaa59bdc7 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 13 Jan 2025 23:38:00 +0100 Subject: [PATCH 31/81] compress without specifying unique --- .../frame/data/compress/CompressedFrameBlockFactory.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/compress/CompressedFrameBlockFactory.java b/src/main/java/org/apache/sysds/runtime/frame/data/compress/CompressedFrameBlockFactory.java index 6cddc728fa0..de7031c7c01 100644 --- 
a/src/main/java/org/apache/sysds/runtime/frame/data/compress/CompressedFrameBlockFactory.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/compress/CompressedFrameBlockFactory.java @@ -208,7 +208,7 @@ private Array compressColFinally(int i, final Array a, final ArrayCompress Timing time = LOG.isDebugEnabled() ? new Timing(true) : null; if(s.bestType != null && s.shouldCompress) { if(s.bestType == FrameArrayType.DDC) - compressedColumns[i] = DDCArray.compressToDDC(a, s.nUnique); + compressedColumns[i] = DDCArray.compressToDDC(a); else throw new RuntimeException("Unsupported frame compression encoding : " + s.bestType); } From 9bf572f092469c66d229852b94270a16e4982cd3 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 13 Jan 2025 23:38:38 +0100 Subject: [PATCH 32/81] sampled all --- .../frame/data/compress/CompressedFrameBlockFactory.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/compress/CompressedFrameBlockFactory.java b/src/main/java/org/apache/sysds/runtime/frame/data/compress/CompressedFrameBlockFactory.java index de7031c7c01..75cd304a9f0 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/compress/CompressedFrameBlockFactory.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/compress/CompressedFrameBlockFactory.java @@ -208,7 +208,7 @@ private Array compressColFinally(int i, final Array a, final ArrayCompress Timing time = LOG.isDebugEnabled() ? new Timing(true) : null; if(s.bestType != null && s.shouldCompress) { if(s.bestType == FrameArrayType.DDC) - compressedColumns[i] = DDCArray.compressToDDC(a); + compressedColumns[i] = DDCArray.compressToDDC(a, s.sampledAllRows ? 
s.nUnique : Integer.MAX_VALUE); else throw new RuntimeException("Unsupported frame compression encoding : " + s.bestType); } From d0d09dc7efabe6de4030e2052682c4de07d18e23 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Tue, 14 Jan 2025 01:05:38 +0100 Subject: [PATCH 33/81] dict writable --- .../runtime/compress/io/DictWritable.java | 56 ++++++++++++++++++- .../runtime/compress/io/WriterCompressed.java | 17 ++---- 2 files changed, 59 insertions(+), 14 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/compress/io/DictWritable.java b/src/main/java/org/apache/sysds/runtime/compress/io/DictWritable.java index 6f5bf1dfef7..29ecf02f017 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/io/DictWritable.java +++ b/src/main/java/org/apache/sysds/runtime/compress/io/DictWritable.java @@ -24,7 +24,11 @@ import java.io.IOException; import java.io.Serializable; import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; import java.util.List; +import java.util.Map; +import java.util.Set; import org.apache.hadoop.io.Writable; import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory; @@ -32,6 +36,7 @@ public class DictWritable implements Writable, Serializable { private static final long serialVersionUID = 731937201435558L; + public List dicts; public DictWritable() { @@ -44,19 +49,65 @@ protected DictWritable(List dicts) { @Override public void write(DataOutput out) throws IOException { + // the dicts can contain duplicates. 
+ // to avoid writing duplicates we run though once to detect them + Set ud = new HashSet<>(); + for(IDictionary d: dicts){ + if(ud.contains(d)){ + writeWithDuplicates(out); + return; + } + ud.add(d); + } + out.writeInt(dicts.size()); for(int i = 0; i < dicts.size(); i++) dicts.get(i).write(out); } + private void writeWithDuplicates(DataOutput out) throws IOException { + // indicate that we use duplicate detection + out.writeInt(dicts.size() * -1); + Map m = new HashMap<>(); + + for(int i = 0; i < dicts.size(); i++){ + int id = m.getOrDefault(dicts.get(i), m.size() ); + out.writeInt(id); + + if(!m.containsKey(dicts.get(i))){ + m.put(dicts.get(i), m.size()); + dicts.get(i).write(out); + } + + } + } + @Override public void readFields(DataInput in) throws IOException { int s = in.readInt(); + if( s < 0){ + readFieldsWithDuplicates(Math.abs(s), in); + } + else{ + dicts = new ArrayList<>(s); + for(int i = 0; i < s; i++) + dicts.add(DictionaryFactory.read(in)); + } + } + + private void readFieldsWithDuplicates(int s, DataInput in) throws IOException { + dicts = new ArrayList<>(s); - for(int i = 0; i < s; i++) - dicts.add(DictionaryFactory.read(in)); + for(int i = 0; i < s; i++){ + int id = in.readInt(); + if(id < i) + dicts.set(i, dicts.get(id)); + else + dicts.add(DictionaryFactory.read(in)); + } } + @Override public String toString() { StringBuilder sb = new StringBuilder(); @@ -64,6 +115,7 @@ public String toString() { for(IDictionary d : dicts) { sb.append(d); sb.append("\n"); + } return sb.toString(); } diff --git a/src/main/java/org/apache/sysds/runtime/compress/io/WriterCompressed.java b/src/main/java/org/apache/sysds/runtime/compress/io/WriterCompressed.java index b77d27f0804..cf39ca6fba9 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/io/WriterCompressed.java +++ b/src/main/java/org/apache/sysds/runtime/compress/io/WriterCompressed.java @@ -57,6 +57,7 @@ import org.apache.sysds.runtime.meta.MatrixCharacteristics; import 
org.apache.sysds.runtime.util.CommonThreadPool; import org.apache.sysds.utils.stats.InfrastructureAnalyzer; +import org.apache.sysds.runtime.util.HDFSTool; public final class WriterCompressed extends MatrixWriter { @@ -146,7 +147,7 @@ private void write(MatrixBlock src, final String fname, final int blen) throws I } fs = IOUtilFunctions.getFileSystem(new Path(fname), job); - + int k = OptimizerUtils.getParallelBinaryWriteParallelism(); k = Math.min(k, (int)(src.getInMemorySize() / InfrastructureAnalyzer.getBlockSize(fs))); @@ -213,8 +214,6 @@ private void writeMultiBlockCompressedSingleThread(MatrixBlock mb, final int rle throws IOException { try { final CompressedMatrixBlock cmb = (CompressedMatrixBlock) mb; - - setupWrite(); final Path path = new Path(fname); Writer w = generateWriter(job, path, fs); for(int bc = 0; bc * blen < clen; bc++) {// column blocks @@ -244,7 +243,6 @@ private void writeMultiBlockCompressedSingleThread(MatrixBlock mb, final int rle private void writeMultiBlockCompressedParallel(MatrixBlock b, final int rlen, final int clen, final int blen, int k) throws IOException { - setupWrite(); final ExecutorService pool = CommonThreadPool.get(k); try { final ArrayList> tasks = new ArrayList<>(); @@ -265,7 +263,8 @@ private void writeMultiBlockCompressedParallel(MatrixBlock b, final int rlen, fi final int colBlocks = (int) Math.ceil((double) clen / blen ); final int nBlocks = (int) Math.ceil((double) rlen / blen); final int blocksPerThread = Math.max(1, nBlocks * colBlocks / k ); - + HDFSTool.deleteFileIfExistOnHDFS(new Path(fname + ".dict"), job); + int i = 0; for(int bc = 0; bc * blen < clen; bc++) {// column blocks final int sC = bc * blen; @@ -307,13 +306,6 @@ private void writeMultiBlockCompressedParallel(MatrixBlock b, final int rlen, fi } } - private void setupWrite() throws IOException { - // final Path path = new Path(fname); - // final JobConf job = ConfigurationManager.getCachedJobConf(); - // HDFSTool.deleteFileIfExistOnHDFS(path, job); - 
// HDFSTool.createDirIfNotExistOnHDFS(path, DMLConfig.DEFAULT_SHARED_DIR_PERMISSION); - } - private Path getPath(int id) { return new Path(fname, IOUtilFunctions.getPartFileName(id)); } @@ -397,6 +389,7 @@ protected DictWriteTask(String fname, List dicts, int id) { public Object call() throws Exception { Path p = new Path(fname + ".dict", IOUtilFunctions.getPartFileName(id)); + HDFSTool.deleteFileIfExistOnHDFS(p, job); try(Writer w = SequenceFile.createWriter(job, Writer.file(p), // Writer.bufferSize(4096), // Writer.keyClass(DictWritable.K.class), // From 280bd1f930bd1491020d8ffc4e0586baf3a48565 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Tue, 14 Jan 2025 01:38:03 +0100 Subject: [PATCH 34/81] not placeholder --- .../sysds/runtime/compress/colgroup/ColGroupSDCZeros.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java index c9f555797b6..69e0f776383 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java @@ -34,6 +34,7 @@ import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory; import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary; +import org.apache.sysds.runtime.compress.colgroup.dictionary.PlaceHolderDict; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData; @@ -93,7 +94,7 @@ public static AColGroup create(IColIndex colIndices, int numRows, IDictionary di int[] cachedCounts) { if(dict == null) return new ColGroupEmpty(colIndices); - else if(data.getUnique() == 1) { + else 
if(data.getUnique() == 1 && !(dict instanceof PlaceHolderDict)) { MatrixBlock mb = dict.getMBDict(colIndices.size()).getMatrixBlock().slice(0, 0); return ColGroupSDCSingleZeros.create(colIndices, numRows, MatrixBlockDictionary.create(mb), offsets, null); } From b60b1ea3569f0c2dfcbd0f88e4d6e24553811a63 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Tue, 14 Jan 2025 14:08:20 +0100 Subject: [PATCH 35/81] fix csv metadata parsing --- .../sysds/runtime/io/FrameReaderTextCSV.java | 219 +++++++++++------- .../io/FrameReaderTextCSVParallel.java | 48 ++-- .../sysds/runtime/io/IOUtilFunctions.java | 9 +- .../test/component/utils/IOUtilsTest.java | 44 ++++ .../TransformCSVFrameEncodeReadTest.java | 13 +- 5 files changed, 226 insertions(+), 107 deletions(-) create mode 100644 src/test/java/org/apache/sysds/test/component/utils/IOUtilsTest.java diff --git a/src/main/java/org/apache/sysds/runtime/io/FrameReaderTextCSV.java b/src/main/java/org/apache/sysds/runtime/io/FrameReaderTextCSV.java index 6a94bcfd50d..8c138c6103c 100644 --- a/src/main/java/org/apache/sysds/runtime/io/FrameReaderTextCSV.java +++ b/src/main/java/org/apache/sysds/runtime/io/FrameReaderTextCSV.java @@ -38,6 +38,7 @@ import org.apache.sysds.conf.ConfigurationManager; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.columns.Array; import org.apache.sysds.runtime.matrix.data.Pair; import org.apache.sysds.runtime.transform.TfUtils; import org.apache.sysds.runtime.util.HDFSTool; @@ -118,22 +119,27 @@ protected final int readCSVFrameFromInputSplit(InputSplit split, InputFormat rlen) // in case this method is called wrongly + if(rl > rlen) // in case this method is called wrongly throw new DMLRuntimeException("Invalid offset"); // return (int) rlen; - boolean hasHeader = _props.hasHeader(); - boolean isFill = _props.isFill(); - double dfillValue = _props.getFillValue(); - String sfillValue = 
String.valueOf(_props.getFillValue()); - Set naValues = _props.getNAStrings(); - String delim = _props.getDelim(); - - // create record reader - RecordReader reader = informat.getRecordReader(split, job, Reporter.NULL); - LongWritable key = new LongWritable(); - Text value = new Text(); - int row = rl; - final int nCol = dest.getNumColumns(); + final boolean hasHeader = _props.hasHeader(); + final boolean isFill = _props.isFill(); + final double dfillValue = _props.getFillValue(); + final String sfillValue = String.valueOf(_props.getFillValue()); + final Set naValues = _props.getNAStrings(); + final String delim = _props.getDelim(); + final CellAssigner f; + if(naValues != null ) + f = FrameReaderTextCSV::assignCellGeneric; + else if(isFill && dfillValue != 0) + f = FrameReaderTextCSV::assignCellFill; + else + f = FrameReaderTextCSV::assignCellNoFill; + + final RecordReader reader = informat.getRecordReader(split, job, Reporter.NULL); + final LongWritable key = new LongWritable(); + final Text value = new Text(); + // handle header if existing if(first && hasHeader) { @@ -142,37 +148,18 @@ protected final int readCSVFrameFromInputSplit(InputSplit split, InputFormat[] destA = dest.getColumns(); while(reader.next(key, value)) // foreach line { - boolean emptyValuesFound = false; - String cellStr = IOUtilFunctions.trim(value.toString()); - parts = IOUtilFunctions.splitCSV(cellStr, delim, parts); - // sanity checks for empty values and number of columns - - final boolean mtdP = parts[0].equals(TfUtils.TXMTD_MVPREFIX); - final boolean mtdx = parts[0].equals(TfUtils.TXMTD_NDPREFIX); - // parse frame meta data (missing values / num distinct) - if(mtdP || mtdx) { - if(parts.length != dest.getNumColumns() + 1){ - LOG.warn("Invalid metadata "); - parts = null; - continue; - } - else if(mtdP) - for(int j = 0; j < dest.getNumColumns(); j++) - dest.getColumnMetadata(j).setMvValue(parts[j + 1]); - else if(mtdx) - for(int j = 0; j < dest.getNumColumns(); j++) - 
dest.getColumnMetadata(j).setNumDistinct(Long.parseLong(parts[j + 1])); - parts = null; + String line = value.toString(); + if(isMetaStart(line)){ + parseMeta(line, delim , dest); continue; } - assignColumns(row, nCol, dest, parts, naValues, isFill, dfillValue, sfillValue); - - IOUtilFunctions.checkAndRaiseErrorCSVEmptyField(cellStr, isFill, emptyValuesFound); - IOUtilFunctions.checkAndRaiseErrorCSVNumColumns("", cellStr, parts, clen); + + parseLine(line, delim, destA, row, (int) clen, dfillValue, sfillValue, isFill, naValues, f); row++; } } @@ -186,43 +173,104 @@ else if(mtdx) return row; } - private boolean assignColumns(int row, int nCol, FrameBlock dest, String[] parts, Set naValues, - boolean isFill, double dfillValue, String sfillValue) { - if(!isFill && naValues == null) - return assignColumnsNoFillNoNan(row, nCol, dest, parts); - else - return assignColumnsGeneric(row, nCol, dest, parts, naValues, isFill, dfillValue, sfillValue); - } + private static boolean isMetaStart(String s){ + return s.charAt(0) == '#' && s.substring(0, 5).equals("#Meta"); - private boolean assignColumnsGeneric(int row, int nCol, FrameBlock dest, String[] parts, Set naValues, - boolean isFill, double dfillValue, String sfillValue) { - boolean emptyValuesFound = false; - for(int col = 0; col < nCol; col++) { - String part = IOUtilFunctions.trim(parts[col]); - if(part.isEmpty() || (naValues != null && naValues.contains(part))) { - if(isFill && dfillValue != 0) - dest.set(row, col, sfillValue); - emptyValuesFound = true; + } + + private static void parseMeta(String s, String delim, FrameBlock dest){ + + String[] parts = IOUtilFunctions.splitCSV(s, delim); + + final boolean mtdP = parts[0].equals(TfUtils.TXMTD_MVPREFIX); + final boolean mtdx = parts[0].equals(TfUtils.TXMTD_NDPREFIX); + + if(parts.length != dest.getNumColumns() + 1){ + LOG.warn("Invalid metadata "); + parts = null; + return; } - else - dest.set(row, col, part); + else if(mtdP) + for(int j = 0; j < dest.getNumColumns(); 
j++) + dest.getColumnMetadata(j).setMvValue(parts[j + 1]); + else if(mtdx) + for(int j = 0; j < dest.getNumColumns(); j++) + dest.getColumnMetadata(j).setNumDistinct(Long.parseLong(parts[j + 1])); + parts = null; + + } + + private static void parseLine(final String cellStr, final String delim, final Array[] destA, final int row, final int clen, final double dfillValue, + final String sfillValue, final boolean isFill, final Set naValues,final CellAssigner assigner) { + try { + final String trimmed = IOUtilFunctions.trim( cellStr); + final int len = trimmed.length(); + final int delimLen = delim.length(); + parseLineSpecialized(trimmed, delim, destA, row, dfillValue, sfillValue, isFill, naValues, len, delimLen, assigner); } + catch(Exception e) { + throw new RuntimeException("failed to parse: " + cellStr, e); + } + } + + private static void parseLineSpecialized(String cellStr, String delim, Array[] destA, int row, double dfillValue, String sfillValue, + boolean isFill, Set naValues, final int len, final int delimLen, final CellAssigner assigner) { + int from = 0, to = 0, c = 0; + while(from < len) { // for all tokens + to = IOUtilFunctions.getTo(cellStr, from, delim, len, delimLen); + String s = cellStr.substring(from, to); + assigner.assign(row, destA[c], s, to - from, naValues, isFill, dfillValue, sfillValue); + c++; + from = to + delimLen; + } + } - return emptyValuesFound; + @FunctionalInterface + private interface CellAssigner{ + void assign(int row, Array dest, String val, int length, Set naValues, boolean isFill, + double dfillValue, String sfillValue); } - private boolean assignColumnsNoFillNoNan(int row, int nCol, FrameBlock dest, String[] parts){ - - boolean emptyValuesFound = false; - for(int col = 0; col < nCol; col++) { - String part = IOUtilFunctions.trim(parts[col]); - if(part.isEmpty()) - emptyValuesFound = true; + + private static void assignCellNoFill(int row, Array dest, String val, int length, Set naValues, boolean isFill, + double dfillValue, 
String sfillValue) { + if(length != 0){ + final String part = IOUtilFunctions.trim(val, length); + if(part.isEmpty()) + return; + dest.set(row, part); + } + } + + + private static void assignCellFill(int row, Array dest, String val, int length, Set naValues, boolean isFill, + double dfillValue, String sfillValue) { + if(length == 0){ + dest.set(row, sfillValue); + } else { + final String part = IOUtilFunctions.trim(val, length); + if(part == null || part.isEmpty()) + dest.set(row, sfillValue); else - dest.set(row, col, part); + dest.set(row, part); } + } - return emptyValuesFound; + private static void assignCellGeneric(int row, Array dest, String val, int length, Set naValues, boolean isFill, + double dfillValue, String sfillValue) { + if(length == 0) { + if(isFill && dfillValue != 0) + dest.set(row, sfillValue); + } + else { + final String part = IOUtilFunctions.trim(val, length); + if(part == null || part.isEmpty() || (naValues != null && naValues.contains(part))) { + if(isFill && dfillValue != 0) + dest.set(row, sfillValue); + } + else + dest.set(row, part); + } } protected Pair computeCSVSize(Path path, JobConf job, FileSystem fs) throws IOException { @@ -248,25 +296,34 @@ protected static long countLinesInSplit(InputSplit split, TextInputFormat inForm throws IOException { RecordReader reader = inFormat.getRecordReader(split, job, Reporter.NULL); - int nrow = 0; try { - LongWritable key = new LongWritable(); - Text value = new Text(); - // ignore header of first split - if(header) - reader.next(key, value); - while(reader.next(key, value)) { - // note the metadata can be located at any row when spark - // (but only at beginning of individual part files) + return countLinesInReader(reader, header); + } + finally { + IOUtilFunctions.closeSilently(reader); + } + } + + private static int countLinesInReader(RecordReader reader, boolean header) + throws IOException { + final LongWritable key = new LongWritable(); + final Text value = new Text(); + + int nrow = 0; + 
// ignore header of first split + if(header) + reader.next(key, value); + while(reader.next(key, value)) { + // (but only at beginning of individual part files) + if(nrow < 3){ String sval = IOUtilFunctions.trim(value.toString()); - boolean containsMTD = nrow<3 && + boolean containsMTD = (sval.startsWith(TfUtils.TXMTD_MVPREFIX) || sval.startsWith(TfUtils.TXMTD_NDPREFIX)); nrow += containsMTD ? 0 : 1; } - } - finally { - IOUtilFunctions.closeSilently(reader); + else + nrow++; } return nrow; } diff --git a/src/main/java/org/apache/sysds/runtime/io/FrameReaderTextCSVParallel.java b/src/main/java/org/apache/sysds/runtime/io/FrameReaderTextCSVParallel.java index 05a259bf6a8..9ce3459d66e 100644 --- a/src/main/java/org/apache/sysds/runtime/io/FrameReaderTextCSVParallel.java +++ b/src/main/java/org/apache/sysds/runtime/io/FrameReaderTextCSVParallel.java @@ -38,6 +38,7 @@ import org.apache.sysds.runtime.matrix.data.Pair; import org.apache.sysds.runtime.util.CommonThreadPool; import org.apache.sysds.runtime.util.HDFSTool; +import org.apache.sysds.utils.stats.Timing; /** * Multi-threaded frame text csv reader. @@ -54,7 +55,8 @@ protected void readCSVFrameFromHDFS( Path path, JobConf job, FileSystem fs, FrameBlock dest, ValueType[] schema, String[] names, long rlen, long clen) throws IOException { - int numThreads = OptimizerUtils.getParallelTextReadParallelism(); + Timing time = new Timing(true); + final int numThreads = OptimizerUtils.getParallelTextReadParallelism(); TextInputFormat informat = new TextInputFormat(); informat.configure(job); @@ -62,29 +64,35 @@ protected void readCSVFrameFromHDFS( Path path, JobConf job, FileSystem fs, if(HDFSTool.isDirectory(fs, path)) splits = IOUtilFunctions.sortInputSplits(splits); - ExecutorService pool = CommonThreadPool.get(numThreads); - try { - // get number of threads pool to use the common thread pool. 
- //compute num rows per split - ArrayList tasks = new ArrayList<>(); - for( int i=0; i> cret = pool.invokeAll(tasks); + final ExecutorService pool = CommonThreadPool.get(numThreads); + try { + if(splits.length == 1){ + new ReadRowsTask(splits[0], informat, job, dest, 0, true).call(); + return; + } + //compute num rows per split + ArrayList> cret = new ArrayList<>(); + for( int i=0; i offsets = new ArrayList<>(); - for( Future count : cret ) { - offsets.add(offset); - offset += count.get(); + ArrayList> tasks2 = new ArrayList<>(); + for( int i=0; i tasks2 = new ArrayList<>(); - for( int i=0; i a : tasks2) + a.get(); + LOG.debug("Finished Reading CSV : " + time.stop()); } catch (Exception e) { throw new IOException("Failed parallel read of text csv input.", e); @@ -137,6 +145,7 @@ private static class CountRowsTask implements Callable { private JobConf _job; private boolean _hasHeader; + public CountRowsTask(InputSplit split, TextInputFormat informat, JobConf job, boolean hasHeader) { _split = split; _informat = informat; @@ -146,7 +155,8 @@ public CountRowsTask(InputSplit split, TextInputFormat informat, JobConf job, bo @Override public Long call() throws Exception { - return countLinesInSplit(_split, _informat, _job, _hasHeader); + long count = countLinesInSplit(_split, _informat, _job, _hasHeader); + return count; } } diff --git a/src/main/java/org/apache/sysds/runtime/io/IOUtilFunctions.java b/src/main/java/org/apache/sysds/runtime/io/IOUtilFunctions.java index e4feb3ed756..c6085cb8960 100644 --- a/src/main/java/org/apache/sysds/runtime/io/IOUtilFunctions.java +++ b/src/main/java/org/apache/sysds/runtime/io/IOUtilFunctions.java @@ -76,6 +76,7 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.data.MatrixCell; import org.apache.sysds.runtime.matrix.data.MatrixIndexes; +import org.apache.sysds.runtime.transform.TfUtils; import org.apache.sysds.runtime.util.LocalFileUtils; import io.airlift.compress.lzo.LzoCodec; @@ 
-695,10 +696,10 @@ public static int countNumColumnsCSV(InputSplit[] splits, InputFormat informat, try { if( reader.next(key, value) ) { boolean hasValue = true; - // if( value.toString().startsWith(TfUtils.TXMTD_MVPREFIX) ) - // hasValue = reader.next(key, value); - // if( value.toString().startsWith(TfUtils.TXMTD_NDPREFIX) ) - // hasValue = reader.next(key, value); + if( value.toString().startsWith(TfUtils.TXMTD_MVPREFIX) ) + hasValue = reader.next(key, value); + if( value.toString().startsWith(TfUtils.TXMTD_NDPREFIX) ) + hasValue = reader.next(key, value); String row = value.toString().trim(); if( hasValue && !row.isEmpty() ) { ncol = IOUtilFunctions.countTokensCSV(row, delim); diff --git a/src/test/java/org/apache/sysds/test/component/utils/IOUtilsTest.java b/src/test/java/org/apache/sysds/test/component/utils/IOUtilsTest.java new file mode 100644 index 00000000000..52be349e70a --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/utils/IOUtilsTest.java @@ -0,0 +1,44 @@ +package org.apache.sysds.test.component.utils; + +import static org.junit.Assert.assertEquals; + +import org.apache.sysds.runtime.io.IOUtilFunctions; +import org.junit.Test; + +public class IOUtilsTest { + + + @Test + public void getTo(){ + String in = ",\"yyy\"·4,"; + assertEquals(0, getTo(in, 0, ",")); + assertEquals(8, getTo(in, 1, ",")); + assertEquals("\"yyy\"·4", in.substring(1, getTo(in, 1, ","))); + } + + @Test + public void getTo2(){ + String in = ",y,"; + assertEquals(0,getTo(in, 0, ",")); + assertEquals(2,getTo(in, 1, ",")); + } + + @Test + public void getTo3(){ + String in = "a,b,c"; + assertEquals("a",in.substring(0,getTo(in, 0, ","))); + assertEquals("b",in.substring(2,getTo(in, 2, ","))); + assertEquals("c",in.substring(4,getTo(in, 4, ","))); + } + + @Test + public void getTo4(){ + String in = "a,\",\",c"; + assertEquals("a",in.substring(0,getTo(in, 0, ","))); + assertEquals("\",\"",in.substring(2,getTo(in, 2, ","))); + } + + private int getTo(String in, int from, 
String delim){ + return IOUtilFunctions.getTo(in, from, ",", in.length(), 1); + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/transform/TransformCSVFrameEncodeReadTest.java b/src/test/java/org/apache/sysds/test/functions/transform/TransformCSVFrameEncodeReadTest.java index f66fc1db3c2..ede8bf96655 100644 --- a/src/test/java/org/apache/sysds/test/functions/transform/TransformCSVFrameEncodeReadTest.java +++ b/src/test/java/org/apache/sysds/test/functions/transform/TransformCSVFrameEncodeReadTest.java @@ -21,6 +21,11 @@ import static org.junit.Assert.fail; +import java.io.File; +import java.util.Scanner; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; import org.apache.sysds.api.DMLScript; import org.apache.sysds.common.Types.ExecMode; import org.apache.sysds.runtime.frame.data.FrameBlock; @@ -35,8 +40,9 @@ import org.junit.Test; -public class TransformCSVFrameEncodeReadTest extends AutomatedTestBase -{ +public class TransformCSVFrameEncodeReadTest extends AutomatedTestBase { + protected static final Log LOG = LogFactory.getLog(TransformCSVFrameEncodeReadTest.class.getName()); + private final static String TEST_NAME1 = "TransformCSVFrameEncodeRead"; private final static String TEST_DIR = "functions/transform/"; private final static String TEST_CLASS_DIR = TEST_DIR + TransformCSVFrameEncodeReadTest.class.getSimpleName() + "/"; @@ -136,7 +142,7 @@ private void runTransformTest( ExecMode rt, String ofmt, boolean subset, boolean DATASET_DIR + DATASET, String.valueOf(nrows), output("R") }; String stdOut = runTest(null).toString(); - + //read input/output and compare FrameReader reader2 = parRead ? 
new FrameReaderTextCSVParallel( new FileFormatPropertiesCSV() ) : @@ -158,6 +164,7 @@ private void runTransformTest( ExecMode rt, String ofmt, boolean subset, boolean } catch(Exception ex) { + ex.printStackTrace(); throw new RuntimeException(ex); } finally { From 7b2b4206762350d8c301f97a87d3e3a828124ee1 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Tue, 14 Jan 2025 15:13:44 +0100 Subject: [PATCH 36/81] license --- .../test/component/utils/IOUtilsTest.java | 52 +++++++++++++------ 1 file changed, 35 insertions(+), 17 deletions(-) diff --git a/src/test/java/org/apache/sysds/test/component/utils/IOUtilsTest.java b/src/test/java/org/apache/sysds/test/component/utils/IOUtilsTest.java index 52be349e70a..440c116fbf8 100644 --- a/src/test/java/org/apache/sysds/test/component/utils/IOUtilsTest.java +++ b/src/test/java/org/apache/sysds/test/component/utils/IOUtilsTest.java @@ -1,3 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + package org.apache.sysds.test.component.utils; import static org.junit.Assert.assertEquals; @@ -6,39 +25,38 @@ import org.junit.Test; public class IOUtilsTest { - - @Test - public void getTo(){ + @Test + public void getTo() { String in = ",\"yyy\"·4,"; assertEquals(0, getTo(in, 0, ",")); assertEquals(8, getTo(in, 1, ",")); assertEquals("\"yyy\"·4", in.substring(1, getTo(in, 1, ","))); } - @Test - public void getTo2(){ + @Test + public void getTo2() { String in = ",y,"; - assertEquals(0,getTo(in, 0, ",")); - assertEquals(2,getTo(in, 1, ",")); + assertEquals(0, getTo(in, 0, ",")); + assertEquals(2, getTo(in, 1, ",")); } - @Test - public void getTo3(){ + @Test + public void getTo3() { String in = "a,b,c"; - assertEquals("a",in.substring(0,getTo(in, 0, ","))); - assertEquals("b",in.substring(2,getTo(in, 2, ","))); - assertEquals("c",in.substring(4,getTo(in, 4, ","))); + assertEquals("a", in.substring(0, getTo(in, 0, ","))); + assertEquals("b", in.substring(2, getTo(in, 2, ","))); + assertEquals("c", in.substring(4, getTo(in, 4, ","))); } - @Test - public void getTo4(){ + @Test + public void getTo4() { String in = "a,\",\",c"; - assertEquals("a",in.substring(0,getTo(in, 0, ","))); - assertEquals("\",\"",in.substring(2,getTo(in, 2, ","))); + assertEquals("a", in.substring(0, getTo(in, 0, ","))); + assertEquals("\",\"", in.substring(2, getTo(in, 2, ","))); } - private int getTo(String in, int from, String delim){ + private int getTo(String in, int from, String delim) { return IOUtilFunctions.getTo(in, from, ",", in.length(), 1); } } From 2ad1bf7b4f8ca3314f715068be81f8e73f6a7109 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Tue, 14 Jan 2025 15:17:46 +0100 Subject: [PATCH 37/81] fix contains key --- .../apache/sysds/runtime/frame/data/columns/HashMapToInt.java | 2 +- .../test/functions/transform/TransformApplyUnknownsTest.java | 1 + .../functions/transform/TransformCSVFrameEncodeReadTest.java | 3 --- 3 files changed, 2 insertions(+), 4 deletions(-) diff 
--git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java index e63c3266126..e9db1264f68 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java @@ -93,7 +93,7 @@ public boolean isEmpty() { @Override @SuppressWarnings({"unchecked"}) public boolean containsKey(Object key) { - return get((K) key) != -1; + return getI((K) key) != -1; } @Override diff --git a/src/test/java/org/apache/sysds/test/functions/transform/TransformApplyUnknownsTest.java b/src/test/java/org/apache/sysds/test/functions/transform/TransformApplyUnknownsTest.java index e4178fe1d03..205ccf75c1b 100644 --- a/src/test/java/org/apache/sysds/test/functions/transform/TransformApplyUnknownsTest.java +++ b/src/test/java/org/apache/sysds/test/functions/transform/TransformApplyUnknownsTest.java @@ -68,6 +68,7 @@ public void testTransformApplyRecode() { Assert.assertTrue(Double.isNaN(out.get(i-1, 0))); } catch (DMLRuntimeException e) { + e.printStackTrace(); throw new RuntimeException(e); } } diff --git a/src/test/java/org/apache/sysds/test/functions/transform/TransformCSVFrameEncodeReadTest.java b/src/test/java/org/apache/sysds/test/functions/transform/TransformCSVFrameEncodeReadTest.java index ede8bf96655..41872c3ba5c 100644 --- a/src/test/java/org/apache/sysds/test/functions/transform/TransformCSVFrameEncodeReadTest.java +++ b/src/test/java/org/apache/sysds/test/functions/transform/TransformCSVFrameEncodeReadTest.java @@ -21,9 +21,6 @@ import static org.junit.Assert.fail; -import java.io.File; -import java.util.Scanner; - import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.api.DMLScript; From b22f744030576d8d11d91349ebf3f2d405b4d3fe Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Tue, 14 Jan 2025 15:35:16 +0100 Subject: 
[PATCH 38/81] writing to disk is painfull --- .../transform/TransformFrameEncodeWordEmbeddingMMTest.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbeddingMMTest.java b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbeddingMMTest.java index 966c7c465b2..8053f3a7426 100644 --- a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbeddingMMTest.java +++ b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbeddingMMTest.java @@ -72,16 +72,16 @@ private void runMatrixMultiplicationTest(String testname, Types.ExecMode rt) List strings = generateRandomStrings(rows, 10); // Generate the dictionary by assigning unique ID to each distinct token - Map map = writeDictToCsvFile(strings, baseDirectory + INPUT_DIR + "dict"); + Map map = writeDictToCsvFile(strings, input(testname + "dict")); // Create the dataset by repeating and shuffling the distinct tokens int factor = 32; rows *= factor; List stringsColumn = shuffleAndMultiplyStrings(strings, factor); - writeStringsToCsvFile(stringsColumn, baseDirectory + INPUT_DIR + "data"); + writeStringsToCsvFile(stringsColumn, input(testname + "data")); //run script - programArgs = new String[]{"-stats","-args", input("embeddings"), input("data"), input("dict"), input("factor"), output("result")}; + programArgs = new String[]{"-stats","-args", input("embeddings"), input(testname + "data"), input(testname + "dict"), input("factor"), output("result")}; runTest(true, EXCEPTION_NOT_EXPECTED, null, -1); // Manually derive the expected result From 4513643dd0c7d56af80130c7866a30187ba7ef5c Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Tue, 14 Jan 2025 16:32:33 +0100 Subject: [PATCH 39/81] more tests for custom array --- .../runtime/frame/data/columns/Array.java | 16 +- .../frame/data/columns/HashMapToInt.java | 90 ++++----- 
.../encode/ColumnEncoderBagOfWords.java | 1 + .../frame/array/HashMapToIntTest.java | 176 ++++++++++++++++++ .../TransformFrameEncodeBagOfWords.java | 1 + 5 files changed, 233 insertions(+), 51 deletions(-) create mode 100644 src/test/java/org/apache/sysds/test/component/frame/array/HashMapToIntTest.java diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java index bfc20de3649..5b61de8f9bd 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java @@ -53,7 +53,7 @@ public abstract class Array implements Writable { public static int ROW_PARALLELIZATION_THRESHOLD = 10000; /** A soft reference to a memorization of this arrays mapping, used in transformEncode */ - protected SoftReference> _rcdMapCache = null; + protected SoftReference> _rcdMapCache = null; /** The current allocated number of elements in this Array */ protected int _size; @@ -73,7 +73,7 @@ protected int newSize() { * * @return The cached recode map */ - public final SoftReference> getCache() { + public final SoftReference> getCache() { return _rcdMapCache; } @@ -82,7 +82,7 @@ public final SoftReference> getCache() { * * @param m The element to cache. 
*/ - public final void setCache(SoftReference> m) { + public final void setCache(SoftReference> m) { _rcdMapCache = m; } @@ -126,11 +126,11 @@ public synchronized final Map getRecodeMap(int estimate) { * @throws ExecutionException if the parallel execution fails * @throws InterruptedException if the parallel execution fails */ - public synchronized final Map getRecodeMap(int estimate, ExecutorService pool, int k) + public synchronized final Map getRecodeMap(int estimate, ExecutorService pool, int k) throws InterruptedException, ExecutionException { // probe cache for existing map - Map map; - SoftReference> tmp = getCache(); + Map map; + SoftReference> tmp = getCache(); map = (tmp != null) ? tmp.get() : null; if(map != null) return map; @@ -207,12 +207,12 @@ private HashMapToInt parallelCreateRecodeMap(int estimate, ExecutorService po */ protected static void mergeRecodeMaps(HashMapToInt target, HashMapToInt from) { final List fromEntriesOrdered = new ArrayList<>(Collections.nCopies(from.size(), null)); - from.forEach((k,v) -> { + from.forEach((k, v) -> { fromEntriesOrdered.set(v - 1, k); }); int id = target.size(); for(T e : fromEntriesOrdered) { - if(target.putIfAbsent(e, id) == null) + if(target.putIfAbsentI(e, id) == -1) id++; } } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java index e9db1264f68..3772174a00f 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java @@ -31,47 +31,13 @@ public class HashMapToInt implements Map, Serializable, Cloneable private static final long serialVersionUID = 3624988207265L; static final int DEFAULT_INITIAL_CAPACITY = 1 << 4; - static final int MAXIMUM_CAPACITY = 1 << 30; static final float DEFAULT_LOAD_FACTOR = 0.75f; - static class Node implements Entry { - final K key; - int value; - Node next; - 
- Node(K key, int value, Node next) { - this.key = key; - this.value = value; - this.next = next; - } - - public final void setNext(Node n) { - next = n; - } - - @Override - public K getKey() { - return key; - } - - @Override - public Integer getValue() { - return value; - } - - @Override - public Integer setValue(Integer value) { - return this.value = value; - } - } - protected Node[] buckets; - int size; - // protected List> keys; - // protected int[][] values; + protected int size; public HashMapToInt(int capacity) { - alloc(Math.max(capacity, 16)); + alloc(Math.max(capacity, DEFAULT_INITIAL_CAPACITY)); } @SuppressWarnings({"unchecked"}) @@ -98,7 +64,14 @@ public boolean containsKey(Object key) { @Override public boolean containsValue(Object value) { - throw new UnsupportedOperationException("Unimplemented method 'containsValue'"); + if(value instanceof Integer) { + for(Entry v : this.entrySet()) { + if(v.getValue().equals(value)) + return true; + } + } + return false; + } @Override @@ -161,7 +134,7 @@ private int putIfAbsentBucket(int ix, K key, int value) { if(b.key.equals(key)) return b.value; if(b.next == null) { - b.next = new Node<>(key, value, null); + b.setNext(new Node<>(key, value, null)); size++; return -1; } @@ -189,12 +162,12 @@ private int addToBucket(int ix, K key, int value) { while(true) { if(b.key.equals(key)) { - int tmp = b.value; - b.value = value; + int tmp = b.getValue(); + b.setValue(value); return tmp; } if(b.next == null) { - b.next = new Node<>(key, value, null); + b.setNext(new Node<>(key, value, null)); size++; return -1; } @@ -246,13 +219,44 @@ public void forEach(BiConsumer action) { @Override public String toString() { - StringBuilder sb = new StringBuilder(); + StringBuilder sb = new StringBuilder(size()*3); this.forEach((k, v) -> { sb.append("(" + k + "→" + v + ")"); }); return sb.toString(); } + private static class Node implements Entry { + final K key; + int value; + Node next; + + Node(K key, int value, Node next) { + 
this.key = key; + this.value = value; + this.next = next; + } + + public final void setNext(Node n) { + next = n; + } + + @Override + public K getKey() { + return key; + } + + @Override + public Integer getValue() { + return value; + } + + @Override + public Integer setValue(Integer value) { + return this.value = value; + } + } + private final class EntrySet extends AbstractSet> { @Override @@ -299,7 +303,7 @@ public Entry next() { break; } } - if(bucketId == buckets.length) + if(bucketId >= buckets.length) next = null; } diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBagOfWords.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBagOfWords.java index badd9e200fb..08c38a92e50 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBagOfWords.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBagOfWords.java @@ -383,6 +383,7 @@ public void writeExternal(ObjectOutput out) throws IOException { out.writeInt(_tokenDictionary == null ? 0 : _tokenDictionary.size()); if(_tokenDictionary != null) for(Map.Entry e : _tokenDictionary.entrySet()) { + System.out.println(e); out.writeUTF((String) e.getKey()); out.writeInt(e.getValue()); } diff --git a/src/test/java/org/apache/sysds/test/component/frame/array/HashMapToIntTest.java b/src/test/java/org/apache/sysds/test/component/frame/array/HashMapToIntTest.java new file mode 100644 index 00000000000..728a27f646d --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/frame/array/HashMapToIntTest.java @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.sysds.test.component.frame.array; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +import java.util.HashMap; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Random; + +import org.apache.sysds.runtime.frame.data.columns.HashMapToInt; +import org.junit.Test; + +public class HashMapToIntTest { + + @Test + public void insert() { + Map m = new HashMapToInt<>(10); + m.put(1, 1); + assertTrue(m.containsKey(1)); + assertTrue(m.containsValue(1)); + } + + @Test + public void isEmpty() { + Map m = new HashMapToInt<>(10); + assertTrue(m.isEmpty()); + m.put(1, 1); + assertFalse(m.isEmpty()); + } + + @Test + public void insert10() { + + Map m = new HashMapToInt<>(10); + for(int i = 0; i < 10; i++) { + m.put(i, i); + assertFalse(m.isEmpty()); + assertTrue(m.containsKey(i)); + assertTrue(m.containsValue(i)); + } + + for(int i = 0; i < 10; i++) { + assertTrue(m.containsKey(i)); + assertTrue(m.containsValue(i)); + } + } + + @Test + public void forEach() { + + Map m = new HashMapToInt<>(10); + Map m2 = new HashMap<>(); + Random r = new Random(32); + for(int i = 0; i < 100; i++) { + int v1 = r.nextInt(); + int v2 = r.nextInt(); + m.put(v1, v2); + m2.put(v1, v2); + } + + assertEquals(m.size(), m2.size()); + for(Entry e : m2.entrySet()) { + assertTrue(m.containsKey(e.getKey())); + } + + assertEquals(m.size(), m2.size()); + for(Entry e : m.entrySet()) { + 
assertTrue(m2.containsKey(e.getKey())); + assertEquals(m.get(e.getKey()), m2.get(e.getKey())); + } + + } + + @Test + public void doNotContainKey() { + Map m = new HashMapToInt<>(10); + for(int i = 0; i < 100; i++) { + assertFalse(m.containsKey(i)); + assertFalse(m.containsValue(i * 10000)); + m.put(i, i * 10000); + assertTrue(m.containsKey(i)); + assertTrue(m.containsValue(i * 10000)); + assertEquals(m.get(i), Integer.valueOf(i * 10000)); + } + + } + + @Test + public void doNotContainValue() { + Map m = new HashMapToInt<>(10); + + assertFalse(m.containsValue(new Object())); + + } + + @Test + public void overwriteKey() { + Map m = new HashMapToInt<>(10); + + Integer v; + v = m.put(1, 10); + assertEquals(Integer.valueOf(10), m.get(Integer.valueOf(1))); + assertEquals(v, null); + v = m.put(1, 11); + assertEquals(v, Integer.valueOf(10)); + assertEquals(Integer.valueOf(11), m.get(Integer.valueOf(1))); + v = m.put(1, 12); + assertEquals(v, Integer.valueOf(11)); + assertEquals(Integer.valueOf(12), m.get(Integer.valueOf(1))); + } + + @Test + public void forEach2() { + Map m = new HashMapToInt<>(10); + for(int i = 900; i < 1000; i++) { + m.put(i, i * 32121523); + } + final Map m2 = new HashMap<>(); + m.forEach((k, v) -> m2.put(k, v)); + m2.forEach((k, v) -> assertTrue("key missing: " + k, m.containsKey(k))); + } + + @Test + public void testToString() { + Map m = new HashMapToInt<>(10); + + assertEquals(0, m.toString().length()); + m.put(555, 321); + String s = m.toString(); + assertTrue(s.contains("555")); + assertTrue(s.contains("321")); + } + + @Test + public void testSizeOfKeySet() { + Map m = new HashMapToInt<>(10); + for(int i = 0; i < 10; i++) { + m.put(i * 321, i * 3222); + assertEquals(m.size(), m.entrySet().size()); + } + } + + @Test + public void putIfAbsent(){ + Map m = new HashMapToInt<>(10); + for(int i = 0; i < 1000; i++) { + assertNull(m.putIfAbsent(i * 321, i * 3222)); + + } + + for(int i = 0; i < 1000; i++) { + assertEquals(i*3222,(int)m.putIfAbsent(i * 
321, i * 3222)); + } + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeBagOfWords.java b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeBagOfWords.java index b8ff0a9b7f4..f1cdd4b0f44 100644 --- a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeBagOfWords.java +++ b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeBagOfWords.java @@ -266,6 +266,7 @@ private void runTransformTest(String testname, ExecMode rt, boolean recode, bool } catch(Exception ex) { + ex.printStackTrace(); throw new RuntimeException(ex); } finally { From 3203299120d589733f0ef14c16c0d83ab3c21246 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Tue, 14 Jan 2025 16:59:17 +0100 Subject: [PATCH 40/81] sum remove on combine --- .../apache/sysds/runtime/compress/lib/CLALibCombineGroups.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCombineGroups.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCombineGroups.java index 3cd9c5e26d4..6045065a3b3 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCombineGroups.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCombineGroups.java @@ -229,6 +229,8 @@ public static AColGroup combine(AColGroup a, AColGroup b, int nRow) { ret = combineUC(combinedColumns, a, b); try { + if(!CompressedMatrixBlock.debug) + return ret; double sumCombined = ret.getSum(nRow); double sumIndividualA = a.getSum(nRow); double sumIndividualB = b.getSum(nRow); From 8f4c7166063a8438dabc3c7b20660a6c3149d87d Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Tue, 14 Jan 2025 17:02:54 +0100 Subject: [PATCH 41/81] combine uncompressed --- .../sysds/runtime/compress/lib/CLALibCombineGroups.java | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCombineGroups.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCombineGroups.java index 6045065a3b3..97efdced7fb 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCombineGroups.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCombineGroups.java @@ -214,10 +214,12 @@ public static AColGroup combine(AColGroup a, AColGroup b, int nRow) { IColIndex combinedColumns = ColIndexFactory.combine(a, b); // try to recompress a and b if uncompressed - if(a instanceof ColGroupUncompressed) + if( (a instanceof ColGroupUncompressed) && (b instanceof ColGroupUncompressed)){ + // do not try to compress if both are uncompressed + } + else if(a instanceof ColGroupUncompressed) a = a.recompress(); - - if(b instanceof ColGroupUncompressed) + else if(b instanceof ColGroupUncompressed) b = b.recompress(); long maxEst = (long) a.getNumValues() * b.getNumValues(); From 661b186512c3d820be641f767a0b1538bb33e12e Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Tue, 14 Jan 2025 20:36:27 +0100 Subject: [PATCH 42/81] combine uncompressed error --- .../colgroup/ColGroupUncompressed.java | 7 ++- .../colgroup/ColGroupUncompressedArray.java | 40 ++++++++++++++++ .../compress/lib/CLALibCombineGroups.java | 1 - .../compress/lib/CLALibRightMultBy.java | 3 -- .../transform/encode/CompressedEncode.java | 48 ++++++++++++------- 5 files changed, 77 insertions(+), 22 deletions(-) create mode 100644 src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java index 0797cce6cdc..cf0959bba7f 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java @@ -80,7 
+80,12 @@ public class ColGroupUncompressed extends AColGroup { */ private final MatrixBlock _data; - private ColGroupUncompressed(MatrixBlock mb, IColIndex colIndexes) { + /** + * Do not use this constructor of column group uncompressed, instead uce the create constructor. + * @param mb The contained data. + * @param colIndexes Column indexes for this Columngroup + */ + protected ColGroupUncompressed(MatrixBlock mb, IColIndex colIndexes) { super(colIndexes); _data = mb; } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java new file mode 100644 index 00000000000..6f4a6f4d40f --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.compress.colgroup; + +import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; +import org.apache.sysds.runtime.frame.data.columns.Array; + +/** + * Special sideways Compressed column group not supposed to be used outside of the compressed transform encode. 
+ */ +public class ColGroupUncompressedArray extends ColGroupUncompressed { + + public final Array array; + public final int id; // columnID + + public ColGroupUncompressedArray(Array data, int id, IColIndex colIndexes){ + super(null, colIndexes); + this.array = data; + this.id = id; + } + + +} diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCombineGroups.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCombineGroups.java index 97efdced7fb..1f660f8df0f 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCombineGroups.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCombineGroups.java @@ -149,7 +149,6 @@ public static AColGroup combineN(List groups, int nRows, ExecutorServ else { return combineNSingleAtATime(groups, nRows); } - } @SuppressWarnings("unchecked") diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java index 5d6de813fcf..966051cd8bd 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java @@ -28,7 +28,6 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import org.apache.sysds.api.DMLScript; import org.apache.sysds.conf.ConfigurationManager; import org.apache.sysds.conf.DMLConfig; import org.apache.sysds.runtime.DMLRuntimeException; @@ -43,8 +42,6 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.operators.BinaryOperator; import org.apache.sysds.runtime.util.CommonThreadPool; -import org.apache.sysds.utils.DMLCompressionStatistics; -import org.apache.sysds.utils.stats.Timing; public final class CLALibRightMultBy { private static final Log LOG = LogFactory.getLog(CLALibRightMultBy.class.getName()); diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java index 893f06ea45e..2514cc1a344 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java @@ -37,7 +37,7 @@ import org.apache.sysds.runtime.compress.colgroup.ColGroupConst; import org.apache.sysds.runtime.compress.colgroup.ColGroupDDC; import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty; -import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed; +import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressedArray; import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.IdentityDictionary; @@ -104,7 +104,9 @@ private MatrixBlock apply() throws Exception { final List encoders = enc.getColumnEncoders(); final List groups = isParallel() ? 
multiThread(encoders) : singleThread(encoders); final int cols = shiftGroups(groups); - final MatrixBlock mb = new CompressedMatrixBlock(in.getNumRows(), cols, -1, false, groups); + final CompressedMatrixBlock mb = new CompressedMatrixBlock(in.getNumRows(), cols, -1, false, groups); + + combineUncompressed(mb); mb.setNonZeros(nnz.get()); logging(mb); return mb; @@ -193,7 +195,7 @@ private AColGroup recodeToDummy(ColumnEncoderComposite c) throws Exception { if(containsNull && domain == 0) return new ColGroupEmpty(ColIndexFactory.create(1)); IColIndex colIndexes = ColIndexFactory.create(0, domain); - if(domain == 1 && !containsNull){ + if(domain == 1 && !containsNull) { nnz.addAndGet(in.getNumRows()); return ColGroupConst.create(colIndexes, new double[] {1}); } @@ -347,10 +349,10 @@ private AColGroup recode(ColumnEncoderComposite c) throws Exception { // int domain = c.getDomainSize(); IColIndex colIndexes = ColIndexFactory.create(1); - if(domain == 0 && containsNull){ + if(domain == 0 && containsNull) { return new ColGroupEmpty(colIndexes); } - if(domain == 1 && !containsNull){ + if(domain == 1 && !containsNull) { nnz.addAndGet(in.getNumRows()); return ColGroupConst.create(colIndexes, new double[] {1}); } @@ -397,14 +399,7 @@ private AColGroup passThroughNormal(ColumnEncoderComposite c, final IColInde if(a.getValueType() != ValueType.BOOLEAN // if not booleans && (stats == null || !stats.shouldCompress || stats.valueType != a.getValueType())) { - // stats.valueType; - double[] vals = (double[]) a.changeType(ValueType.FP64).get(); - - MatrixBlock col = new MatrixBlock(a.size(), 1, vals); - long nz = col.recomputeNonZeros(1); - - nnz.addAndGet(nz); - return ColGroupUncompressed.create(colIndexes, col, false); + return new ColGroupUncompressedArray(a, c._colID - 1,colIndexes); } else { boolean containsNull = a.containsNull(); @@ -532,10 +527,10 @@ private AColGroup hash(ColumnEncoderComposite c) { int domain = (int) CEHash.getK(); boolean nulls = a.containsNull(); 
IColIndex colIndexes = ColIndexFactory.create(0, 1); - if(domain == 0 && nulls){ + if(domain == 0 && nulls) { return new ColGroupEmpty(colIndexes); } - if(domain == 1 && !nulls){ + if(domain == 1 && !nulls) { nnz.addAndGet(in.getNumRows()); return ColGroupConst.create(colIndexes, new double[] {1}); } @@ -561,10 +556,10 @@ private AColGroup hashToDummy(ColumnEncoderComposite c) { int domain = (int) CEHash.getK(); boolean nulls = a.containsNull(); IColIndex colIndexes = ColIndexFactory.create(0, domain); - if(domain == 0 && nulls){ + if(domain == 0 && nulls) { return new ColGroupEmpty(ColIndexFactory.create(1)); } - if(domain == 1 && !nulls){ + if(domain == 1 && !nulls) { nnz.addAndGet(in.getNumRows()); return ColGroupConst.create(colIndexes, new double[] {1}); } @@ -609,6 +604,25 @@ private void estimateRCDMapSize(ColumnEncoderComposite c) { c._estNumDistincts = estDistCount; } + private void combineUncompressed(CompressedMatrixBlock mb) { + + List ucg = new ArrayList<>(); + List ret = new ArrayList<>(); + for(AColGroup g : mb.getColGroups()) { + if(g instanceof ColGroupUncompressedArray) + ucg.add((ColGroupUncompressedArray) g); + else + ret.add(g); + } + ret.add(combine(ucg)); + nnz.addAndGet(ret.get(ret.size()-1).getNumberNonZeros(in.getNumRows())); + mb.allocateColGroupList(ret); + } + + private AColGroup combine(List ucg) { + throw new NotImplementedException("Should combine " + ucg.size()); + } + private void logging(MatrixBlock mb) { if(LOG.isDebugEnabled()) { LOG.debug(String.format("Uncompressed transform encode Dense size: %16d", mb.estimateSizeDenseInMemory())); From 7bedb88f498f7735cde6228071da4b60687d0ca1 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Tue, 14 Jan 2025 20:44:33 +0100 Subject: [PATCH 43/81] more functions --- .../colgroup/ColGroupUncompressedArray.java | 13 +++++++++++++ .../compress/colgroup/indexes/ColIndexFactory.java | 2 +- .../runtime/transform/encode/CompressedEncode.java | 1 + 3 files changed, 15 insertions(+), 1 
deletion(-) diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java index 6f4a6f4d40f..b95cbe05637 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java @@ -21,6 +21,7 @@ import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.frame.data.columns.Array; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; /** * Special sideways Compressed column group not supposed to be used outside of the compressed transform encode. @@ -37,4 +38,16 @@ public ColGroupUncompressedArray(Array data, int id, IColIndex colIndexes){ } + @Override + public int getNumValues(){ + return array.size(); + } + + + @Override + public long estimateInMemorySize(){ + // not accurate estimate, but guaranteed larger. 
+ return MatrixBlock.estimateSizeInMemory(array.size(),1,array.size()) + 80; + } + } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/ColIndexFactory.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/ColIndexFactory.java index c9a45e4aeea..53a2c60e98e 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/ColIndexFactory.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/ColIndexFactory.java @@ -126,7 +126,7 @@ else if(contiguous) return ArrayIndex.estimateInMemorySizeStatic(nCol); } - public static IColIndex combine(List gs) { + public static IColIndex combine(List gs) { int numCols = 0; for(AColGroup g : gs) numCols += g.getNumCols(); diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java index 2514cc1a344..880c0b9f782 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java @@ -620,6 +620,7 @@ private void combineUncompressed(CompressedMatrixBlock mb) { } private AColGroup combine(List ucg) { + IColIndex combinedCols = ColIndexFactory.combine(ucg); throw new NotImplementedException("Should combine " + ucg.size()); } From 3e983e4814f0c56690ef82e58fb3a849ea43a356 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Tue, 14 Jan 2025 20:52:52 +0100 Subject: [PATCH 44/81] simple version --- .../transform/encode/CompressedEncode.java | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java index 880c0b9f782..eee66e03300 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java +++ 
b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java @@ -37,6 +37,7 @@ import org.apache.sysds.runtime.compress.colgroup.ColGroupConst; import org.apache.sysds.runtime.compress.colgroup.ColGroupDDC; import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty; +import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed; import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressedArray; import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; @@ -48,6 +49,7 @@ import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory; import org.apache.sysds.runtime.compress.estim.sample.SampleEstimatorFactory; import org.apache.sysds.runtime.compress.utils.Util; +import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.frame.data.columns.ACompressedArray; import org.apache.sysds.runtime.frame.data.columns.Array; @@ -621,7 +623,20 @@ private void combineUncompressed(CompressedMatrixBlock mb) { private AColGroup combine(List ucg) { IColIndex combinedCols = ColIndexFactory.combine(ucg); - throw new NotImplementedException("Should combine " + ucg.size()); + + ucg.sort((a,b) -> Integer.compare(a.id,b.id)); + MatrixBlock ret = new MatrixBlock(in.getNumRows(), combinedCols.size(), false); + ret.allocateDenseBlock(); + DenseBlock db = ret.getDenseBlock(); + for(int i =0; i < in.getNumRows(); i++){ + double[] rval = db.values(i); + int off = db.pos(i); + for(int j = 0; j < combinedCols.size(); j++){ + rval[off + j] = ucg.get(j).array.getAsDouble(i); + } + } + + return ColGroupUncompressed.create(ret, combinedCols); } private void logging(MatrixBlock mb) { From b1f7e6edb7241f9b3c3f25b212bdd9d99e918c2a Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Tue, 14 Jan 2025 20:59:40 +0100 Subject: [PATCH 45/81] more logging --- 
.../compress/colgroup/ColGroupUncompressedArray.java | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java index b95cbe05637..6deb00f4bf3 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java @@ -50,4 +50,10 @@ public long estimateInMemorySize(){ return MatrixBlock.estimateSizeInMemory(array.size(),1,array.size()) + 80; } + + @Override + public String toString(){ + return "UncompressedArrayGroup: " + id + " " + _colIndexes; + } + } From 613f397e7fda50e0349c2cef9276810c3ae1dcf9 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Tue, 14 Jan 2025 21:01:42 +0100 Subject: [PATCH 46/81] logging --- .../apache/sysds/runtime/transform/encode/CompressedEncode.java | 1 + 1 file changed, 1 insertion(+) diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java index eee66e03300..ac9f0ed4e84 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java @@ -622,6 +622,7 @@ private void combineUncompressed(CompressedMatrixBlock mb) { } private AColGroup combine(List ucg) { + LOG.error(ucg); IColIndex combinedCols = ColIndexFactory.combine(ucg); ucg.sort((a,b) -> Integer.compare(a.id,b.id)); From 4172e64820636791a96a4336d0c8bb1810bbef62 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Tue, 14 Jan 2025 21:05:26 +0100 Subject: [PATCH 47/81] safety fix --- .../sysds/runtime/transform/encode/CompressedEncode.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java index ac9f0ed4e84..1173c766168 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java @@ -616,7 +616,8 @@ private void combineUncompressed(CompressedMatrixBlock mb) { else ret.add(g); } - ret.add(combine(ucg)); + if(ucg.size() > 0) + ret.add(combine(ucg)); nnz.addAndGet(ret.get(ret.size()-1).getNumberNonZeros(in.getNumRows())); mb.allocateColGroupList(ret); } From 4758bce650228da734a439063c2cc375ceedef07 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Tue, 14 Jan 2025 21:06:22 +0100 Subject: [PATCH 48/81] only add nnz if combined --- .../sysds/runtime/transform/encode/CompressedEncode.java | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java index 1173c766168..55f1ef4d3da 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java @@ -401,7 +401,7 @@ private AColGroup passThroughNormal(ColumnEncoderComposite c, final IColInde if(a.getValueType() != ValueType.BOOLEAN // if not booleans && (stats == null || !stats.shouldCompress || stats.valueType != a.getValueType())) { - return new ColGroupUncompressedArray(a, c._colID - 1,colIndexes); + return new ColGroupUncompressedArray(a, c._colID - 1, colIndexes); } else { boolean containsNull = a.containsNull(); @@ -616,9 +616,10 @@ private void combineUncompressed(CompressedMatrixBlock mb) { else ret.add(g); } - if(ucg.size() > 0) + if(ucg.size() > 0){ ret.add(combine(ucg)); - nnz.addAndGet(ret.get(ret.size()-1).getNumberNonZeros(in.getNumRows())); + 
nnz.addAndGet(ret.get(ret.size()-1).getNumberNonZeros(in.getNumRows())); + } mb.allocateColGroupList(ret); } From 495be27f76405996c4eb725816f61fcd7c2bf326 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Tue, 14 Jan 2025 21:10:19 +0100 Subject: [PATCH 49/81] fixes --- .../colgroup/ColGroupUncompressedArray.java | 21 ++++++++----------- 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java index 6deb00f4bf3..680513225a5 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java @@ -24,35 +24,32 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock; /** - * Special sideways Compressed column group not supposed to be used outside of the compressed transform encode. + * Special sideways compressed column group not supposed to be used outside of the compressed transform encode. */ public class ColGroupUncompressedArray extends ColGroupUncompressed { - + public final Array array; public final int id; // columnID - public ColGroupUncompressedArray(Array data, int id, IColIndex colIndexes){ + public ColGroupUncompressedArray(Array data, int id, IColIndex colIndexes) { super(null, colIndexes); this.array = data; this.id = id; } - - @Override - public int getNumValues(){ + @Override + public int getNumValues() { return array.size(); } - @Override - public long estimateInMemorySize(){ + public long estimateInMemorySize() { // not accurate estimate, but guaranteed larger. 
- return MatrixBlock.estimateSizeInMemory(array.size(),1,array.size()) + 80; + return MatrixBlock.estimateSizeInMemory(array.size(), 1, array.size()) + 80; } - - @Override - public String toString(){ + @Override + public String toString() { return "UncompressedArrayGroup: " + id + " " + _colIndexes; } From 115fb901ae759f073fe43eab8f105d16fffd271c Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Tue, 14 Jan 2025 21:12:19 +0100 Subject: [PATCH 50/81] bad logging add --- .../apache/sysds/runtime/transform/encode/CompressedEncode.java | 1 + 1 file changed, 1 insertion(+) diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java index 55f1ef4d3da..55af4cb83ab 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java @@ -620,6 +620,7 @@ private void combineUncompressed(CompressedMatrixBlock mb) { ret.add(combine(ucg)); nnz.addAndGet(ret.get(ret.size()-1).getNumberNonZeros(in.getNumRows())); } + LOG.error(ret); mb.allocateColGroupList(ret); } From 30d2d1879c93c6f8e1dfc92da7d2a4942eebd560 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Tue, 14 Jan 2025 21:15:57 +0100 Subject: [PATCH 51/81] debugging move --- .../sysds/runtime/transform/encode/CompressedEncode.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java index 55af4cb83ab..aaa2ac08065 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java @@ -615,17 +615,17 @@ private void combineUncompressed(CompressedMatrixBlock mb) { ucg.add((ColGroupUncompressedArray) g); else ret.add(g); 
- } + } + LOG.error(ucg); + LOG.error(ret); if(ucg.size() > 0){ ret.add(combine(ucg)); nnz.addAndGet(ret.get(ret.size()-1).getNumberNonZeros(in.getNumRows())); } - LOG.error(ret); mb.allocateColGroupList(ret); } private AColGroup combine(List ucg) { - LOG.error(ucg); IColIndex combinedCols = ColIndexFactory.combine(ucg); ucg.sort((a,b) -> Integer.compare(a.id,b.id)); From 17ad997a40a661a65533e4c5c16bfecc20b7b84f Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Tue, 14 Jan 2025 21:20:18 +0100 Subject: [PATCH 52/81] why? --- .../runtime/compress/colgroup/ColGroupUncompressedArray.java | 1 + .../apache/sysds/runtime/transform/encode/CompressedEncode.java | 1 + 2 files changed, 2 insertions(+) diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java index 680513225a5..68a030c6f3c 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java @@ -27,6 +27,7 @@ * Special sideways compressed column group not supposed to be used outside of the compressed transform encode. 
*/ public class ColGroupUncompressedArray extends ColGroupUncompressed { + private static final long serialVersionUID = -825423333043292199L; public final Array array; public final int id; // columnID diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java index aaa2ac08065..c1c9450f571 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java @@ -401,6 +401,7 @@ private AColGroup passThroughNormal(ColumnEncoderComposite c, final IColInde if(a.getValueType() != ValueType.BOOLEAN // if not booleans && (stats == null || !stats.shouldCompress || stats.valueType != a.getValueType())) { + LOG.error("Create temporary Uncompressed ColumnGroupArray"); return new ColGroupUncompressedArray(a, c._colID - 1, colIndexes); } else { From 847c01bdf617a35023b2bcd696c24f1a537c74a4 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Tue, 14 Jan 2025 21:30:17 +0100 Subject: [PATCH 53/81] no longer an extension of Uncompressed --- .../colgroup/ColGroupUncompressedArray.java | 232 +++++++++++++++++- 1 file changed, 230 insertions(+), 2 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java index 68a030c6f3c..b0c90d7adac 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java @@ -19,21 +19,34 @@ package org.apache.sysds.runtime.compress.colgroup; +import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils.P; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; +import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme; +import 
org.apache.sysds.runtime.compress.cost.ComputationCostEstimator; +import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup; +import org.apache.sysds.runtime.data.DenseBlock; +import org.apache.sysds.runtime.data.SparseBlock; +import org.apache.sysds.runtime.data.SparseBlockMCSR; import org.apache.sysds.runtime.frame.data.columns.Array; +import org.apache.sysds.runtime.instructions.cp.CM_COV_Object; import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator; +import org.apache.sysds.runtime.matrix.operators.BinaryOperator; +import org.apache.sysds.runtime.matrix.operators.CMOperator; +import org.apache.sysds.runtime.matrix.operators.ScalarOperator; +import org.apache.sysds.runtime.matrix.operators.UnaryOperator; /** * Special sideways compressed column group not supposed to be used outside of the compressed transform encode. */ -public class ColGroupUncompressedArray extends ColGroupUncompressed { +public class ColGroupUncompressedArray extends AColGroup { private static final long serialVersionUID = -825423333043292199L; public final Array array; public final int id; // columnID public ColGroupUncompressedArray(Array data, int id, IColIndex colIndexes) { - super(null, colIndexes); + super(colIndexes); this.array = data; this.id = id; } @@ -54,4 +67,219 @@ public String toString() { return "UncompressedArrayGroup: " + id + " " + _colIndexes; } + @Override + public AColGroup copyAndSet(IColIndex colIndexes) { + throw new UnsupportedOperationException("Unimplemented method 'copyAndSet'"); + } + + @Override + public void decompressToDenseBlockTransposed(DenseBlock db, int rl, int ru) { + throw new UnsupportedOperationException("Unimplemented method 'decompressToDenseBlockTransposed'"); + } + + @Override + public void decompressToSparseBlockTransposed(SparseBlockMCSR sb, int nColOut) { + throw new UnsupportedOperationException("Unimplemented method 
'decompressToSparseBlockTransposed'"); + } + + @Override + public double getIdx(int r, int colIdx) { + throw new UnsupportedOperationException("Unimplemented method 'getIdx'"); + } + + @Override + public CompressionType getCompType() { + throw new UnsupportedOperationException("Unimplemented method 'getCompType'"); + } + + @Override + protected ColGroupType getColGroupType() { + throw new UnsupportedOperationException("Unimplemented method 'getColGroupType'"); + } + + @Override + public void decompressToDenseBlock(DenseBlock db, int rl, int ru, int offR, int offC) { + throw new UnsupportedOperationException("Unimplemented method 'decompressToDenseBlock'"); + } + + @Override + public void decompressToSparseBlock(SparseBlock sb, int rl, int ru, int offR, int offC) { + throw new UnsupportedOperationException("Unimplemented method 'decompressToSparseBlock'"); + } + + @Override + public AColGroup rightMultByMatrix(MatrixBlock right, IColIndex allCols, int k) { + throw new UnsupportedOperationException("Unimplemented method 'rightMultByMatrix'"); + } + + @Override + public void tsmm(MatrixBlock ret, int nRows) { + throw new UnsupportedOperationException("Unimplemented method 'tsmm'"); + } + + @Override + public void leftMultByMatrixNoPreAgg(MatrixBlock matrix, MatrixBlock result, int rl, int ru, int cl, int cu) { + throw new UnsupportedOperationException("Unimplemented method 'leftMultByMatrixNoPreAgg'"); + } + + @Override + public void leftMultByAColGroup(AColGroup lhs, MatrixBlock result, int nRows) { + throw new UnsupportedOperationException("Unimplemented method 'leftMultByAColGroup'"); + } + + @Override + public void tsmmAColGroup(AColGroup other, MatrixBlock result) { + throw new UnsupportedOperationException("Unimplemented method 'tsmmAColGroup'"); + } + + @Override + public AColGroup scalarOperation(ScalarOperator op) { + throw new UnsupportedOperationException("Unimplemented method 'scalarOperation'"); + } + + @Override + public AColGroup 
binaryRowOpLeft(BinaryOperator op, double[] v, boolean isRowSafe) { + throw new UnsupportedOperationException("Unimplemented method 'binaryRowOpLeft'"); + } + + @Override + public AColGroup binaryRowOpRight(BinaryOperator op, double[] v, boolean isRowSafe) { + throw new UnsupportedOperationException("Unimplemented method 'binaryRowOpRight'"); + } + + @Override + public void unaryAggregateOperations(AggregateUnaryOperator op, double[] c, int nRows, int rl, int ru) { + throw new UnsupportedOperationException("Unimplemented method 'unaryAggregateOperations'"); + } + + @Override + protected AColGroup sliceSingleColumn(int idx) { + throw new UnsupportedOperationException("Unimplemented method 'sliceSingleColumn'"); + } + + @Override + protected AColGroup sliceMultiColumns(int idStart, int idEnd, IColIndex outputCols) { + throw new UnsupportedOperationException("Unimplemented method 'sliceMultiColumns'"); + } + + @Override + public AColGroup sliceRows(int rl, int ru) { + throw new UnsupportedOperationException("Unimplemented method 'sliceRows'"); + } + + @Override + public double getMin() { + throw new UnsupportedOperationException("Unimplemented method 'getMin'"); + } + + @Override + public double getMax() { + throw new UnsupportedOperationException("Unimplemented method 'getMax'"); + } + + @Override + public double getSum(int nRows) { + throw new UnsupportedOperationException("Unimplemented method 'getSum'"); + } + + @Override + public boolean containsValue(double pattern) { + throw new UnsupportedOperationException("Unimplemented method 'containsValue'"); + } + + @Override + public long getNumberNonZeros(int nRows) { + throw new UnsupportedOperationException("Unimplemented method 'getNumberNonZeros'"); + } + + @Override + public AColGroup replace(double pattern, double replace) { + throw new UnsupportedOperationException("Unimplemented method 'replace'"); + } + + @Override + public void computeColSums(double[] c, int nRows) { + throw new 
UnsupportedOperationException("Unimplemented method 'computeColSums'"); + } + + @Override + public CM_COV_Object centralMoment(CMOperator op, int nRows) { + throw new UnsupportedOperationException("Unimplemented method 'centralMoment'"); + } + + @Override + public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows) { + throw new UnsupportedOperationException("Unimplemented method 'rexpandCols'"); + } + + @Override + public double getCost(ComputationCostEstimator e, int nRows) { + throw new UnsupportedOperationException("Unimplemented method 'getCost'"); + } + + @Override + public AColGroup unaryOperation(UnaryOperator op) { + throw new UnsupportedOperationException("Unimplemented method 'unaryOperation'"); + } + + @Override + public boolean isEmpty() { + throw new UnsupportedOperationException("Unimplemented method 'isEmpty'"); + } + + @Override + public AColGroup append(AColGroup g) { + throw new UnsupportedOperationException("Unimplemented method 'append'"); + } + + @Override + protected AColGroup appendNInternal(AColGroup[] groups, int blen, int rlen) { + throw new UnsupportedOperationException("Unimplemented method 'appendNInternal'"); + } + + @Override + public ICLAScheme getCompressionScheme() { + throw new UnsupportedOperationException("Unimplemented method 'getCompressionScheme'"); + } + + @Override + public AColGroup recompress() { + throw new UnsupportedOperationException("Unimplemented method 'recompress'"); + } + + @Override + public CompressedSizeInfoColGroup getCompressionInfo(int nRow) { + throw new UnsupportedOperationException("Unimplemented method 'getCompressionInfo'"); + } + + @Override + protected AColGroup fixColIndexes(IColIndex newColIndex, int[] reordering) { + throw new UnsupportedOperationException("Unimplemented method 'fixColIndexes'"); + } + + @Override + public AColGroup reduceCols() { + throw new UnsupportedOperationException("Unimplemented method 'reduceCols'"); + } + + @Override + public double getSparsity() { + 
throw new UnsupportedOperationException("Unimplemented method 'getSparsity'"); + } + + @Override + protected void sparseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) { + throw new UnsupportedOperationException("Unimplemented method 'sparseSelection'"); + } + + @Override + protected void denseSelection(MatrixBlock selection, P[] points, MatrixBlock ret, int rl, int ru) { + throw new UnsupportedOperationException("Unimplemented method 'denseSelection'"); + } + + @Override + public AColGroup[] splitReshape(int multiplier, int nRow, int nColOrg) { + throw new UnsupportedOperationException("Unimplemented method 'splitReshape'"); + } + } From 099779ab17c0e6d6f4f4136ea0b43210bc3768c4 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Tue, 14 Jan 2025 21:32:51 +0100 Subject: [PATCH 54/81] copy and set --- .../runtime/compress/colgroup/ColGroupUncompressedArray.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java index b0c90d7adac..31e29341645 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressedArray.java @@ -69,7 +69,7 @@ public String toString() { @Override public AColGroup copyAndSet(IColIndex colIndexes) { - throw new UnsupportedOperationException("Unimplemented method 'copyAndSet'"); + return new ColGroupUncompressedArray(array, id, colIndexes); } @Override From 56c2bec4a1e642dccd2c848913aaf588d0bcaa64 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Tue, 14 Jan 2025 21:36:41 +0100 Subject: [PATCH 55/81] count nnz --- .../apache/sysds/runtime/transform/encode/CompressedEncode.java | 2 ++ 1 file changed, 2 insertions(+) diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java index c1c9450f571..e89aba93a58 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java @@ -641,6 +641,8 @@ private AColGroup combine(List ucg) { } } + ret.recomputeNonZeros(k); + return ColGroupUncompressed.create(ret, combinedCols); } From 1a150d34b5eeb6bd76548114add960a1edd0e3f9 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Tue, 14 Jan 2025 21:37:27 +0100 Subject: [PATCH 56/81] remove logging --- .../sysds/runtime/transform/encode/CompressedEncode.java | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java index e89aba93a58..fa8911359b6 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java @@ -401,7 +401,6 @@ private AColGroup passThroughNormal(ColumnEncoderComposite c, final IColInde if(a.getValueType() != ValueType.BOOLEAN // if not booleans && (stats == null || !stats.shouldCompress || stats.valueType != a.getValueType())) { - LOG.error("Create temporary Uncompressed ColumnGroupArray"); return new ColGroupUncompressedArray(a, c._colID - 1, colIndexes); } else { @@ -617,8 +616,6 @@ private void combineUncompressed(CompressedMatrixBlock mb) { else ret.add(g); } - LOG.error(ucg); - LOG.error(ret); if(ucg.size() > 0){ ret.add(combine(ucg)); nnz.addAndGet(ret.get(ret.size()-1).getNumberNonZeros(in.getNumRows())); From fbee61570f50700d7ed8c7799750301835fe07d5 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Tue, 14 Jan 2025 21:50:32 +0100 Subject: [PATCH 57/81] parallel putinto --- 
.../transform/encode/CompressedEncode.java | 59 ++++++++++++++----- 1 file changed, 45 insertions(+), 14 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java index fa8911359b6..a9f2d0b60ca 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java @@ -606,7 +606,7 @@ private void estimateRCDMapSize(ColumnEncoderComposite c) { c._estNumDistincts = estDistCount; } - private void combineUncompressed(CompressedMatrixBlock mb) { + private void combineUncompressed(CompressedMatrixBlock mb) throws InterruptedException, ExecutionException { List ucg = new ArrayList<>(); List ret = new ArrayList<>(); @@ -615,34 +615,65 @@ private void combineUncompressed(CompressedMatrixBlock mb) { ucg.add((ColGroupUncompressedArray) g); else ret.add(g); - } - if(ucg.size() > 0){ + } + if(ucg.size() > 0) { ret.add(combine(ucg)); - nnz.addAndGet(ret.get(ret.size()-1).getNumberNonZeros(in.getNumRows())); + nnz.addAndGet(ret.get(ret.size() - 1).getNumberNonZeros(in.getNumRows())); } mb.allocateColGroupList(ret); } - private AColGroup combine(List ucg) { + private AColGroup combine(List ucg) throws InterruptedException, ExecutionException { IColIndex combinedCols = ColIndexFactory.combine(ucg); - ucg.sort((a,b) -> Integer.compare(a.id,b.id)); + ucg.sort((a, b) -> Integer.compare(a.id, b.id)); MatrixBlock ret = new MatrixBlock(in.getNumRows(), combinedCols.size(), false); ret.allocateDenseBlock(); - DenseBlock db = ret.getDenseBlock(); - for(int i =0; i < in.getNumRows(); i++){ - double[] rval = db.values(i); - int off = db.pos(i); - for(int j = 0; j < combinedCols.size(); j++){ - rval[off + j] = ucg.get(j).array.getAsDouble(i); - } - } + final DenseBlock db = ret.getDenseBlock(); + final int nrow = in.getNumRows(); + final int ncol = 
combinedCols.size(); + if(isParallel() && (long) nrow * ncol > 10000 && nrow > 512) + parallelPutInto(ucg, db, nrow, ncol); + else + putInto(ucg, db, 0, nrow, 0, ncol); ret.recomputeNonZeros(k); return ColGroupUncompressed.create(ret, combinedCols); } + private void parallelPutInto(List ucg, DenseBlock db, int nrow, int ncol) + throws InterruptedException, ExecutionException { + List> tasks = new ArrayList<>(); + + final int iblk = Math.max(512, nrow / k); + final int jblk = Math.min(128, ncol); + for(int i = 0; i < nrow; i += iblk) { + int si = i; + int ei = Math.min(nrow, iblk + i); + for(int j = 0; j < ncol; j += jblk) { + int sj = j; + int ej = Math.min(ncol, jblk + j); + tasks.add(pool.submit(() -> { + putInto(ucg, db, si, ei, sj, ej); + })); + } + } + + for(Future t : tasks) + t.get(); + } + + private void putInto(List ucg, DenseBlock db, int il, int iu, int jl, int ju) { + for(int i = il; i < iu; i++) { + final double[] rval = db.values(i); + final int off = db.pos(i); + for(int j = jl; j < ju; j++) { + rval[off + j] = ucg.get(j).array.getAsDouble(i); + } + } + } + private void logging(MatrixBlock mb) { if(LOG.isDebugEnabled()) { LOG.debug(String.format("Uncompressed transform encode Dense size: %16d", mb.estimateSizeDenseInMemory())); From d39895ab31720b92c7c74d9097e0cd7b56da44b3 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Tue, 14 Jan 2025 21:53:52 +0100 Subject: [PATCH 58/81] count nnz --- .../transform/encode/CompressedEncode.java | 27 ++++++++++--------- 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java index a9f2d0b60ca..3e4606b30a0 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java @@ -632,19 +632,19 @@ private AColGroup combine(List ucg) throws 
Interrupte final DenseBlock db = ret.getDenseBlock(); final int nrow = in.getNumRows(); final int ncol = combinedCols.size(); + final long combinedNNZ; if(isParallel() && (long) nrow * ncol > 10000 && nrow > 512) - parallelPutInto(ucg, db, nrow, ncol); + combinedNNZ = parallelPutInto(ucg, db, nrow, ncol); else - putInto(ucg, db, 0, nrow, 0, ncol); - - ret.recomputeNonZeros(k); + combinedNNZ = putInto(ucg, db, 0, nrow, 0, ncol); + nnz.addAndGet(combinedNNZ); return ColGroupUncompressed.create(ret, combinedCols); } - private void parallelPutInto(List ucg, DenseBlock db, int nrow, int ncol) + private long parallelPutInto(List ucg, DenseBlock db, int nrow, int ncol) throws InterruptedException, ExecutionException { - List> tasks = new ArrayList<>(); + List> tasks = new ArrayList<>(); final int iblk = Math.max(512, nrow / k); final int jblk = Math.min(128, ncol); @@ -655,23 +655,26 @@ private void parallelPutInto(List ucg, DenseBlock db, int sj = j; int ej = Math.min(ncol, jblk + j); tasks.add(pool.submit(() -> { - putInto(ucg, db, si, ei, sj, ej); + return putInto(ucg, db, si, ei, sj, ej); })); } } - - for(Future t : tasks) - t.get(); + long nnz = 0; + for(Future t : tasks) + nnz += t.get(); + return nnz; } - private void putInto(List ucg, DenseBlock db, int il, int iu, int jl, int ju) { + private long putInto(List ucg, DenseBlock db, int il, int iu, int jl, int ju) { + long nnz = 0; for(int i = il; i < iu; i++) { final double[] rval = db.values(i); final int off = db.pos(i); for(int j = jl; j < ju; j++) { - rval[off + j] = ucg.get(j).array.getAsDouble(i); + nnz += (rval[off + j] = ucg.get(j).array.getAsDouble(i)) == 0.0 ? 
1 : 0; } } + return nnz; } private void logging(MatrixBlock mb) { From 578308a1fa0315aacf22d2ff2ee7f95ad5096caf Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Tue, 14 Jan 2025 21:56:30 +0100 Subject: [PATCH 59/81] set nnz --- .../apache/sysds/runtime/transform/encode/CompressedEncode.java | 1 + 1 file changed, 1 insertion(+) diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java index 3e4606b30a0..2acccfe2473 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java @@ -639,6 +639,7 @@ private AColGroup combine(List ucg) throws Interrupte combinedNNZ = putInto(ucg, db, 0, nrow, 0, ncol); nnz.addAndGet(combinedNNZ); + ret.setNonZeros(combinedNNZ); return ColGroupUncompressed.create(ret, combinedCols); } From 6cd34c264d297d512f6363f28cbbc8f8980eb299 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Tue, 14 Jan 2025 22:56:48 +0100 Subject: [PATCH 60/81] better parallelization --- .../transform/encode/CompressedEncode.java | 120 ++++++++++-------- 1 file changed, 68 insertions(+), 52 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java index 2acccfe2473..550da6fca4a 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java @@ -82,6 +82,8 @@ public class CompressedEncode { private final AtomicLong nnz = new AtomicLong(); + private static final IColIndex SINGLE_COL_TMP_INDEX = ColIndexFactory.create(1); + private CompressedEncode(MultiColumnEncoder enc, FrameBlock in, int k) { this.enc = enc; this.in = in; @@ -107,8 +109,6 @@ private MatrixBlock apply() throws Exception { final 
List groups = isParallel() ? multiThread(encoders) : singleThread(encoders); final int cols = shiftGroups(groups); final CompressedMatrixBlock mb = new CompressedMatrixBlock(in.getNumRows(), cols, -1, false, groups); - - combineUncompressed(mb); mb.setNonZeros(nnz.get()); logging(mb); return mb; @@ -125,16 +125,36 @@ private boolean isParallel() { private List singleThread(List encoders) throws Exception { List groups = new ArrayList<>(encoders.size()); - for(ColumnEncoderComposite c : encoders) - groups.add(encode(c)); + List ucg = new ArrayList<>(); + for(ColumnEncoderComposite c : encoders) { + AColGroup g = encode(c); + if(g instanceof ColGroupUncompressedArray) + ucg.add((ColGroupUncompressedArray) g); + else + groups.add(g); + } + if(ucg.size() > 0) { + groups.add(combine(ucg)); + } return groups; } private List multiThread(List encoders) throws Exception { final List> tasks = new ArrayList<>(encoders.size()); - for(ColumnEncoderComposite c : encoders) - tasks.add(pool.submit(() -> encode(c))); + final List> ucgTasks = new ArrayList<>(); + for(ColumnEncoderComposite c : encoders) { + + Array a = in.getColumn(c._colID - 1); + if(c.isPassThrough() && !(a instanceof ACompressedArray) && uncompressedPassThrough(a)) + ucgTasks.add(pool.submit(() -> encode(c))); + else + tasks.add(pool.submit(() -> encode(c))); + } final List groups = new ArrayList<>(encoders.size()); + if(!ucgTasks.isEmpty()) { + groups.add(combineFutures(ucgTasks)); + } + for(Future t : tasks) groups.add(t.get()); return groups; @@ -383,44 +403,48 @@ private ADictionary createRecodeDictionary(boolean containsNull, int domain) { @SuppressWarnings("unchecked") private AColGroup passThrough(ColumnEncoderComposite c) throws Exception { - - final IColIndex colIndexes = ColIndexFactory.create(1); - final int colId = c._colID; - final Array a = (Array) in.getColumn(colId - 1); + final int colId = c._colID - 1; + final Array a = (Array) in.getColumn(colId); if(a instanceof ACompressedArray) - return 
passThroughCompressed(colIndexes, a); + return passThroughCompressed(a); + else if(uncompressedPassThrough(a)) + return new ColGroupUncompressedArray(a, colId, SINGLE_COL_TMP_INDEX); else - return passThroughNormal(c, colIndexes, a); + return compressingPassThrough(c, a); } - private AColGroup passThroughNormal(ColumnEncoderComposite c, final IColIndex colIndexes, final Array a) + private AColGroup compressingPassThrough(ColumnEncoderComposite c, final Array a) throws InterruptedException, ExecutionException, Exception { - // Take a small sample - ArrayCompressionStatistics stats = !inputContainsCompressed ? // - a.statistics(Math.min(1000, a.size())) : null; + boolean containsNull = a.containsNull(); + estimateRCDMapSize(c); + HashMapToInt map = (HashMapToInt) a.getRecodeMap(c._estNumDistincts, pool, k / in.getNumColumns()); + double[] vals = new double[map.size() + (containsNull ? 1 : 0)]; + if(containsNull) + vals[map.size()] = Double.NaN; + ValueType t = a.getValueType(); + map.forEach((k, v) -> vals[v.intValue() - 1] = UtilFunctions.objectToDouble(t, k)); + ADictionary d = Dictionary.create(vals); + AMapToData m = createMappingAMapToData(a, map, containsNull); + AColGroup ret = ColGroupDDC.create(SINGLE_COL_TMP_INDEX, d, m, null); + nnz.addAndGet(ret.getNumberNonZeros(in.getNumRows())); + return ret; + } + + private boolean uncompressedPassThrough(final Array a) { + + if(a.getValueType() != ValueType.BOOLEAN) { - if(a.getValueType() != ValueType.BOOLEAN // if not booleans - && (stats == null || !stats.shouldCompress || stats.valueType != a.getValueType())) { - return new ColGroupUncompressedArray(a, c._colID - 1, colIndexes); + ArrayCompressionStatistics stats = !inputContainsCompressed ? // + a.statistics(Math.min(1000, a.size())) : null; + return stats == null // if some columns already are compressed then most likely we do not need to + || !stats.shouldCompress // if we should compress ... 
lets + || stats.valueType != a.getValueType(); // if the compression says change value type, then do not do it. } - else { - boolean containsNull = a.containsNull(); - estimateRCDMapSize(c); - HashMapToInt map = (HashMapToInt) a.getRecodeMap(c._estNumDistincts, pool, k / in.getNumColumns()); - double[] vals = new double[map.size() + (containsNull ? 1 : 0)]; - if(containsNull) - vals[map.size()] = Double.NaN; - ValueType t = a.getValueType(); - map.forEach((k, v) -> vals[v.intValue() - 1] = UtilFunctions.objectToDouble(t, k)); - ADictionary d = Dictionary.create(vals); - AMapToData m = createMappingAMapToData(a, map, containsNull); - AColGroup ret = ColGroupDDC.create(colIndexes, d, m, null); - nnz.addAndGet(ret.getNumberNonZeros(in.getNumRows())); - return ret; - } - } - - private AColGroup passThroughCompressed(final IColIndex colIndexes, final Array a) { + + return false;// if not booleans + } + + private AColGroup passThroughCompressed(final Array a) { // only DDC possible currently. 
DDCArray aDDC = (DDCArray) a; Array dict = aDDC.getDict(); @@ -433,7 +457,7 @@ private AColGroup passThroughCompressed(final IColIndex colIndexes, final Ar vals[i] = dict.getAsDouble(i); ADictionary d = Dictionary.create(vals); - AColGroup ret = ColGroupDDC.create(colIndexes, d, aDDC.getMap(), null); + AColGroup ret = ColGroupDDC.create(SINGLE_COL_TMP_INDEX, d, aDDC.getMap(), null); nnz.addAndGet(ret.getNumberNonZeros(in.getNumRows())); return ret; @@ -606,21 +630,12 @@ private void estimateRCDMapSize(ColumnEncoderComposite c) { c._estNumDistincts = estDistCount; } - private void combineUncompressed(CompressedMatrixBlock mb) throws InterruptedException, ExecutionException { - - List ucg = new ArrayList<>(); - List ret = new ArrayList<>(); - for(AColGroup g : mb.getColGroups()) { - if(g instanceof ColGroupUncompressedArray) - ucg.add((ColGroupUncompressedArray) g); - else - ret.add(g); + private AColGroup combineFutures(List> ucgTasks) throws InterruptedException, ExecutionException { + List ucg = new ArrayList<>(ucgTasks.size()); + for(Future g : ucgTasks) { + ucg.add((ColGroupUncompressedArray) g.get()); } - if(ucg.size() > 0) { - ret.add(combine(ucg)); - nnz.addAndGet(ret.get(ret.size() - 1).getNumberNonZeros(in.getNumRows())); - } - mb.allocateColGroupList(ret); + return combine(ucg); } private AColGroup combine(List ucg) throws InterruptedException, ExecutionException { @@ -640,6 +655,7 @@ private AColGroup combine(List ucg) throws Interrupte nnz.addAndGet(combinedNNZ); ret.setNonZeros(combinedNNZ); + return ColGroupUncompressed.create(ret, combinedCols); } From 2e7b3010722dc4be7ca801501dee9ce018265239 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Tue, 14 Jan 2025 22:59:54 +0100 Subject: [PATCH 61/81] SINGLE COL TMP INDEX --- .../transform/encode/CompressedEncode.java | 24 +++++++------------ 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java 
b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java index 550da6fca4a..ca7731b1dcf 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java @@ -236,13 +236,12 @@ private AColGroup bin(ColumnEncoderComposite c) throws InterruptedException, Exe final ColumnEncoderBin b = (ColumnEncoderBin) r.get(0); b.build(in); final boolean containsNull = b.containsNull; - final IColIndex colIndexes = ColIndexFactory.create(1); ADictionary d = createIncrementingVector(b._numBin, containsNull); final AMapToData m; m = binEncode(a, b, containsNull); - AColGroup ret = ColGroupDDC.create(colIndexes, d, m, null); + AColGroup ret = ColGroupDDC.create(SINGLE_COL_TMP_INDEX, d, m, null); nnz.addAndGet(ret.getNumberNonZeros(in.getNumRows())); return ret; } @@ -369,14 +368,11 @@ private AColGroup recode(ColumnEncoderComposite c) throws Exception { boolean containsNull = a.containsNull(); int domain = map.size(); - // int domain = c.getDomainSize(); - IColIndex colIndexes = ColIndexFactory.create(1); - if(domain == 0 && containsNull) { - return new ColGroupEmpty(colIndexes); - } + if(domain == 0 && containsNull) + return new ColGroupEmpty(SINGLE_COL_TMP_INDEX); if(domain == 1 && !containsNull) { nnz.addAndGet(in.getNumRows()); - return ColGroupConst.create(colIndexes, new double[] {1}); + return ColGroupConst.create(SINGLE_COL_TMP_INDEX, new double[] {1}); } ADictionary d = createRecodeDictionary(containsNull, domain); @@ -384,7 +380,7 @@ private AColGroup recode(ColumnEncoderComposite c) throws Exception { List r = c.getEncoders(); r.set(0, new ColumnEncoderRecode(colId, (HashMapToInt) map)); - AColGroup ret = ColGroupDDC.create(colIndexes, d, m, null); + AColGroup ret = ColGroupDDC.create(SINGLE_COL_TMP_INDEX, d, m, null); nnz.addAndGet(ret.getNumberNonZeros(in.getNumRows())); return ret; @@ -552,13 +548,11 @@ private AColGroup 
hash(ColumnEncoderComposite c) { ColumnEncoderFeatureHash CEHash = (ColumnEncoderFeatureHash) c.getEncoders().get(0); int domain = (int) CEHash.getK(); boolean nulls = a.containsNull(); - IColIndex colIndexes = ColIndexFactory.create(0, 1); - if(domain == 0 && nulls) { - return new ColGroupEmpty(colIndexes); - } + if(domain == 0 && nulls) + return new ColGroupEmpty(SINGLE_COL_TMP_INDEX); if(domain == 1 && !nulls) { nnz.addAndGet(in.getNumRows()); - return ColGroupConst.create(colIndexes, new double[] {1}); + return ColGroupConst.create(SINGLE_COL_TMP_INDEX, new double[] {1}); } MatrixBlock incrementing = new MatrixBlock(domain + (nulls ? 1 : 0), 1, false); @@ -570,7 +564,7 @@ private AColGroup hash(ColumnEncoderComposite c) { ADictionary d = MatrixBlockDictionary.create(incrementing); AMapToData m = createHashMappingAMapToData(a, domain, nulls); - AColGroup ret = ColGroupDDC.create(colIndexes, d, m, null); + AColGroup ret = ColGroupDDC.create(SINGLE_COL_TMP_INDEX, d, m, null); nnz.addAndGet(ret.getNumberNonZeros(in.getNumRows())); return ret; } From a449d505141b5941847eb7e4c6f7aeb995f1cf82 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Wed, 15 Jan 2025 00:33:17 +0100 Subject: [PATCH 62/81] try ? 
--- .../transform/encode/CompressedEncode.java | 39 ++++++++++++++----- 1 file changed, 30 insertions(+), 9 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java index ca7731b1dcf..8275d294822 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java @@ -48,6 +48,7 @@ import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData; import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory; import org.apache.sysds.runtime.compress.estim.sample.SampleEstimatorFactory; +import org.apache.sysds.runtime.compress.utils.IntArrayList; import org.apache.sysds.runtime.compress.utils.Util; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.frame.data.FrameBlock; @@ -133,9 +134,9 @@ private List singleThread(List encoders) thro else groups.add(g); } - if(ucg.size() > 0) { + if(ucg.size() > 0) groups.add(combine(ucg)); - } + return groups; } @@ -152,7 +153,7 @@ private List multiThread(List encoders) throw } final List groups = new ArrayList<>(encoders.size()); if(!ucgTasks.isEmpty()) { - groups.add(combineFutures(ucgTasks)); + tasks.add(pool.submit(() -> combineFutures(ucgTasks))); } for(Future t : tasks) @@ -167,12 +168,32 @@ private List multiThread(List encoders) throw * @return The total number of columns contained. 
*/ private int shiftGroups(List groups) { - int cols = groups.get(0).getColIndices().size(); - for(int i = 1; i < groups.size(); i++) { - groups.set(i, groups.get(i).shiftColIndices(cols)); - cols += groups.get(i).getColIndices().size(); + + int curCols = 0; + int curGroup = 0; + final List encoders = enc.getColumnEncoders(); + final IntArrayList ucCols = new IntArrayList(); + + // ColIndexFactory.create + for(int i = 0; i < encoders.size(); i++ ){ + // for each encoder ... + ColumnEncoderComposite c = encoders.get(i); + Array a = in.getColumn(c._colID - 1); + if(c.isPassThrough() && !(a instanceof ACompressedArray) && uncompressedPassThrough(a)){ + ucCols.appendValue(curCols++); + } + else { + AColGroup g = groups.get(curGroup); + groups.set( curGroup, g.shiftColIndices(curCols)); + curCols += g.getColIndices().size(); + } + } + if( ucCols.size() > 0){ + int i = groups.size()-1; + AColGroup g =groups.get(i); + groups.set(i, g.copyAndSet(ColIndexFactory.create(ucCols))); } - return cols; + return curCols; } private AColGroup encode(ColumnEncoderComposite c) throws Exception { @@ -633,7 +654,7 @@ private AColGroup combineFutures(List> ucgTasks) throws Interr } private AColGroup combine(List ucg) throws InterruptedException, ExecutionException { - IColIndex combinedCols = ColIndexFactory.combine(ucg); + IColIndex combinedCols = ColIndexFactory.create(ucg.size()); ucg.sort((a, b) -> Integer.compare(a.id, b.id)); MatrixBlock ret = new MatrixBlock(in.getNumRows(), combinedCols.size(), false); From 110dec802bb31820ad66b57f3be88cf3db5ec1e8 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Wed, 15 Jan 2025 00:44:10 +0100 Subject: [PATCH 63/81] timing of combining --- .../sysds/runtime/transform/encode/CompressedEncode.java | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java index 
8275d294822..277688e51f2 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java @@ -180,6 +180,8 @@ private int shiftGroups(List groups) { ColumnEncoderComposite c = encoders.get(i); Array a = in.getColumn(c._colID - 1); if(c.isPassThrough() && !(a instanceof ACompressedArray) && uncompressedPassThrough(a)){ + // if this encoder was part of the uncompressed encoders. + // do not shift the column indexes because we combined all uncompressed columnGroups. ucCols.appendValue(curCols++); } else { @@ -654,6 +656,7 @@ private AColGroup combineFutures(List> ucgTasks) throws Interr } private AColGroup combine(List ucg) throws InterruptedException, ExecutionException { + final Timing t = new Timing(); IColIndex combinedCols = ColIndexFactory.create(ucg.size()); ucg.sort((a, b) -> Integer.compare(a.id, b.id)); @@ -670,7 +673,9 @@ private AColGroup combine(List ucg) throws Interrupte nnz.addAndGet(combinedNNZ); ret.setNonZeros(combinedNNZ); - + if(LOG.isDebugEnabled()) + LOG.debug("Combining of : " + ucg.size() + " uncompressed columns Time:" + t); + return ColGroupUncompressed.create(ret, combinedCols); } From 4c0f7eb946e4951fc433e37d5616e288a7e85d04 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Wed, 15 Jan 2025 00:46:05 +0100 Subject: [PATCH 64/81] compressed size ... 
even if abort --- .../sysds/runtime/compress/CompressedMatrixBlockFactory.java | 1 + 1 file changed, 1 insertion(+) diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java index 90505aa6004..83b21ab2359 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java +++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java @@ -348,6 +348,7 @@ private void classifyPhase() { LOG.info("Threshold was set to : " + threshold + " but it was above original " + _stats.originalCost); LOG.info("Original size : " + _stats.originalSize); LOG.info("single col size : " + _stats.estimatedSizeCols); + LOG.debug(String.format("--compressed size: %16d", _stats.originalSize)); if(!(costEstimator instanceof MemoryCostEstimator)) { LOG.info("original cost : " + _stats.originalCost); LOG.info("single col cost : " + _stats.estimatedCostCols); From 51dc8235bc0340828212bfb22307bde0b253eb03 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Wed, 15 Jan 2025 01:18:34 +0100 Subject: [PATCH 65/81] parallel --- .../transform/encode/CompressedEncode.java | 80 ++++++++++++++----- 1 file changed, 62 insertions(+), 18 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java index 277688e51f2..21db73d6a62 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java @@ -54,7 +54,9 @@ import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.frame.data.columns.ACompressedArray; import org.apache.sysds.runtime.frame.data.columns.Array; +import org.apache.sysds.runtime.frame.data.columns.ArrayFactory; import 
org.apache.sysds.runtime.frame.data.columns.DDCArray; +import org.apache.sysds.runtime.frame.data.columns.DoubleArray; import org.apache.sysds.runtime.frame.data.columns.HashMapToInt; import org.apache.sysds.runtime.frame.data.compress.ArrayCompressionStatistics; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -136,7 +138,7 @@ private List singleThread(List encoders) thro } if(ucg.size() > 0) groups.add(combine(ucg)); - + return groups; } @@ -175,24 +177,24 @@ private int shiftGroups(List groups) { final IntArrayList ucCols = new IntArrayList(); // ColIndexFactory.create - for(int i = 0; i < encoders.size(); i++ ){ + for(int i = 0; i < encoders.size(); i++) { // for each encoder ... ColumnEncoderComposite c = encoders.get(i); Array a = in.getColumn(c._colID - 1); - if(c.isPassThrough() && !(a instanceof ACompressedArray) && uncompressedPassThrough(a)){ + if(c.isPassThrough() && !(a instanceof ACompressedArray) && uncompressedPassThrough(a)) { // if this encoder was part of the uncompressed encoders. // do not shift the column indexes because we combined all uncompressed columnGroups. ucCols.appendValue(curCols++); } else { AColGroup g = groups.get(curGroup); - groups.set( curGroup, g.shiftColIndices(curCols)); + groups.set(curGroup, g.shiftColIndices(curCols)); curCols += g.getColIndices().size(); } } - if( ucCols.size() > 0){ - int i = groups.size()-1; - AColGroup g =groups.get(i); + if(ucCols.size() > 0) { + int i = groups.size() - 1; + AColGroup g = groups.get(i); groups.set(i, g.copyAndSet(ColIndexFactory.create(ucCols))); } return curCols; @@ -463,17 +465,13 @@ private boolean uncompressedPassThrough(final Array a) { return false;// if not booleans } - private AColGroup passThroughCompressed(final Array a) { + private AColGroup passThroughCompressed(final Array a) throws InterruptedException, ExecutionException { // only DDC possible currently. 
DDCArray aDDC = (DDCArray) a; Array dict = aDDC.getDict(); - double[] vals = new double[dict.size()]; - if(a.containsNull()) - for(int i = 0; i < dict.size(); i++) - vals[i] = dict.getAsNaNDouble(i); - else - for(int i = 0; i < dict.size(); i++) - vals[i] = dict.getAsDouble(i); + final int dSize = dict.size(); + + final double[] vals = passThroughCompressedCreateDict(a, dict, dSize); ADictionary d = Dictionary.create(vals); AColGroup ret = ColGroupDDC.create(SINGLE_COL_TMP_INDEX, d, aDDC.getMap(), null); @@ -482,6 +480,52 @@ private AColGroup passThroughCompressed(final Array a) { return ret; } + private double[] passThroughCompressedCreateDict(final Array a, Array dict, final int dSize) throws InterruptedException, ExecutionException { + final double[] vals; + final boolean nulls = a.containsNull(); + if(dict.getValueType() == ValueType.FP64 && !nulls) { + DoubleArray converted = ((DoubleArray) dict); + vals = converted.get(); + } + else if(!nulls) { + DoubleArray converted = ArrayFactory.create(new double[dSize]); + passThroughTransferNoNulls(dict, dSize, converted); + vals = converted.get(); + } + else { + vals = passThroughTransferNulls(dict, dSize); + } + return vals; + } + + private double[] passThroughTransferNulls(Array dict, final int dSize) { + final double[] vals; + vals = new double[dSize]; + for(int i = 0; i < dSize; i++) { + vals[i] = dict.getAsNaNDouble(i); + } + return vals; + } + + private void passThroughTransferNoNulls(Array dict, final int dSize, DoubleArray converted) throws InterruptedException, ExecutionException { + if(isParallel() && dSize > 10000){ + final int blkz = Math.min(10000 , (dSize + k) / k); + final List> tasks = new ArrayList<>(); + for(int i = 0; i < dSize ; i += blkz){ + int si = i; + int ei = Math.min(dSize, i + blkz); + tasks.add(pool.submit(() -> { + dict.changeType(converted, si, ei); + })); + } + for(Future t : tasks) + t.get(); + } + else { + dict.changeType(converted, 0, dSize); + } + } + private AMapToData 
createMappingAMapToData(Array a, HashMapToInt map, boolean containsNull) throws Exception { final int si = map.size(); @@ -674,8 +718,8 @@ private AColGroup combine(List ucg) throws Interrupte nnz.addAndGet(combinedNNZ); ret.setNonZeros(combinedNNZ); if(LOG.isDebugEnabled()) - LOG.debug("Combining of : " + ucg.size() + " uncompressed columns Time:" + t); - + LOG.debug("Combining of : " + ucg.size() + " uncompressed columns Time: " + t.stop()); + return ColGroupUncompressed.create(ret, combinedCols); } @@ -708,7 +752,7 @@ private long putInto(List ucg, DenseBlock db, int il, final double[] rval = db.values(i); final int off = db.pos(i); for(int j = jl; j < ju; j++) { - nnz += (rval[off + j] = ucg.get(j).array.getAsDouble(i)) == 0.0 ? 1 : 0; + nnz += (rval[off + j] = ucg.get(j).array.getAsNaNDouble(i)) == 0.0 ? 1 : 0; } } return nnz; From 7e746bfd1aa432bc4a0ef5d98e31952fb0d480c7 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Wed, 15 Jan 2025 01:21:06 +0100 Subject: [PATCH 66/81] more JIT --- .../transform/encode/CompressedEncode.java | 26 ++++++++++++------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java index 21db73d6a62..c2733812e9a 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java @@ -480,7 +480,8 @@ private AColGroup passThroughCompressed(final Array a) throws Interrupted return ret; } - private double[] passThroughCompressedCreateDict(final Array a, Array dict, final int dSize) throws InterruptedException, ExecutionException { + private double[] passThroughCompressedCreateDict(final Array a, Array dict, final int dSize) + throws InterruptedException, ExecutionException { final double[] vals; final boolean nulls = a.containsNull(); if(dict.getValueType() == 
ValueType.FP64 && !nulls) { @@ -507,11 +508,12 @@ private double[] passThroughTransferNulls(Array dict, final int dSize) { return vals; } - private void passThroughTransferNoNulls(Array dict, final int dSize, DoubleArray converted) throws InterruptedException, ExecutionException { - if(isParallel() && dSize > 10000){ - final int blkz = Math.min(10000 , (dSize + k) / k); + private void passThroughTransferNoNulls(Array dict, final int dSize, DoubleArray converted) + throws InterruptedException, ExecutionException { + if(isParallel() && dSize > 10000) { + final int blkz = Math.min(10000, (dSize + k) / k); final List> tasks = new ArrayList<>(); - for(int i = 0; i < dSize ; i += blkz){ + for(int i = 0; i < dSize; i += blkz) { int si = i; int ei = Math.min(dSize, i + blkz); tasks.add(pool.submit(() -> { @@ -746,14 +748,20 @@ private long parallelPutInto(List ucg, DenseBlock db, return nnz; } - private long putInto(List ucg, DenseBlock db, int il, int iu, int jl, int ju) { + private final long putInto(List ucg, DenseBlock db, int il, int iu, int jl, int ju) { long nnz = 0; for(int i = il; i < iu; i++) { final double[] rval = db.values(i); final int off = db.pos(i); - for(int j = jl; j < ju; j++) { - nnz += (rval[off + j] = ucg.get(j).array.getAsNaNDouble(i)) == 0.0 ? 1 : 0; - } + nnz = putIntoRowBlock(ucg, jl, ju, nnz, i, rval, off); + } + return nnz; + } + + private final long putIntoRowBlock(List ucg, int jl, int ju, long nnz, int i, + final double[] rval, final int off) { + for(int j = jl; j < ju; j++) { + nnz += (rval[off + j] = ucg.get(j).array.getAsNaNDouble(i)) == 0.0 ? 
1 : 0; } return nnz; } From c52ad51def87453757ca9a53b3c389f37e4f8e08 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Wed, 15 Jan 2025 01:48:09 +0100 Subject: [PATCH 67/81] keynull --- .../sysds/runtime/frame/data/columns/Array.java | 8 ++++---- .../runtime/frame/data/columns/DDCArray.java | 6 ++---- .../runtime/frame/data/columns/FloatArray.java | 16 ---------------- .../runtime/frame/data/columns/HashMapToInt.java | 12 ++++++++---- 4 files changed, 14 insertions(+), 28 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java index 5b61de8f9bd..9cdd3012e7e 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java @@ -970,12 +970,12 @@ protected boolean earlyAbortEstimateDistinct(int distinctFound, int samplesTaken && distinctFound * 100 >= samplesTaken * 60; // More than 60 % distinct } - protected int setAndAddToDict(Map rcd, AMapToData m, int i, Integer id) { + protected int setAndAddToDict(HashMapToInt rcd, AMapToData m, int i, Integer id) { final T val = getInternal(i); - final Integer v = rcd.get(val); - if(v == null) { + final int v = rcd.putIfAbsentI(val, id); + if(v == -1) { m.set(i, id); - rcd.put(val, id++); + id++; } else m.set(i, v); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java index 69bfa38e7ff..2505ed6c275 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java @@ -22,8 +22,6 @@ import java.io.DataInput; import java.io.DataOutput; import java.io.IOException; -import java.util.HashMap; -import java.util.Map; import java.util.Map.Entry; import java.util.concurrent.ExecutionException; import 
java.util.concurrent.ExecutorService; @@ -141,7 +139,7 @@ public static Array compressToDDC(Array arr, int estimateUnique) { final int t = getTryThreshold(arr.getValueType(), s, arr.getInMemorySize()); // One pass algorithm... - final Map rcd = new HashMap<>(); + final HashMapToInt rcd = new HashMapToInt(estimateUnique); // map should guarantee to be able to hold the distinct values. final AMapToData m = MapToFactory.create(s, Math.min(t, estimateUnique)); Integer id = 0; @@ -157,7 +155,7 @@ public static Array compressToDDC(Array arr, int estimateUnique) { // Allocate the correct dictionary output final Array ar; - if(rcd.keySet().contains(null)) + if(rcd.containsKey(null)) ar = (Array) ArrayFactory.allocateOptional(arr.getValueType(), rcd.size()); else ar = (Array) ArrayFactory.allocate(arr.getValueType(), rcd.size()); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/FloatArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/FloatArray.java index d586c2f32a8..d0ab7a56305 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/FloatArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/FloatArray.java @@ -25,11 +25,9 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.Arrays; -import java.util.Map; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.DMLRuntimeException; -import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData; import org.apache.sysds.runtime.frame.data.columns.ArrayFactory.FrameArrayType; import org.apache.sysds.runtime.matrix.data.Pair; import org.apache.sysds.runtime.util.UtilFunctions; @@ -393,20 +391,6 @@ public boolean possiblyContainsNaN() { return true; } - @Override - protected int setAndAddToDict(Map rcd, AMapToData m, int i, Integer id) { - // JIT. 
- final Float val = _data[i]; - final Integer v = rcd.get(val); - if(v == null) { - m.set(i, id); - rcd.put(val, id++); - } - else - m.set(i, v); - return id; - } - @Override public String toString() { StringBuilder sb = new StringBuilder(_size * 5 + 2); diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java index 3772174a00f..7f377dd633a 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java @@ -87,9 +87,10 @@ public Integer get(Object key) { public int getI(K key) { final int ix = hash(key); Node b = buckets[ix]; + final boolean keyNull = key == null; if(b != null) { do { - if(b.key.equals(key)) + if((keyNull && b.key == null) || (b.key != null && b.key.equals(key))) return b.value; } while((b = b.next) != null); @@ -98,6 +99,8 @@ public int getI(K key) { } public int hash(K key) { + if(key == null) + return 0; return Math.abs(key.hashCode()) % buckets.length; } @@ -130,8 +133,9 @@ public int putIfAbsentI(K key, int value) { private int putIfAbsentBucket(int ix, K key, int value) { Node b = buckets[ix]; + final boolean keyNull = key == null; while(true) { - if(b.key.equals(key)) + if((keyNull && b.key == null) || (b.key != null && b.key.equals(key))) return b.value; if(b.next == null) { b.setNext(new Node<>(key, value, null)); @@ -159,9 +163,9 @@ private int createBucket(int ix, K key, int value) { private int addToBucket(int ix, K key, int value) { Node b = buckets[ix]; + final boolean keyNull = key == null; while(true) { - - if(b.key.equals(key)) { + if((keyNull && b.key == null) || (b.key != null && b.key.equals(key))){ int tmp = b.getValue(); b.setValue(value); return tmp; From 782245c856f2ad401a29e30e1fa9c9b9c94f0d17 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Wed, 15 Jan 2025 01:53:01 +0100 Subject: [PATCH 68/81] again 
--- .../org/apache/sysds/runtime/frame/data/columns/DDCArray.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java index 2505ed6c275..5342e26f3c1 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java @@ -139,7 +139,7 @@ public static Array compressToDDC(Array arr, int estimateUnique) { final int t = getTryThreshold(arr.getValueType(), s, arr.getInMemorySize()); // One pass algorithm... - final HashMapToInt rcd = new HashMapToInt(estimateUnique); + final HashMapToInt rcd = new HashMapToInt(estimateUnique == Integer.MAX_VALUE ? 16 : estimateUnique); // map should guarantee to be able to hold the distinct values. final AMapToData m = MapToFactory.create(s, Math.min(t, estimateUnique)); Integer id = 0; From 12c6d1a43dc975cc3f137a46176900f16656f2b8 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Wed, 15 Jan 2025 02:17:09 +0100 Subject: [PATCH 69/81] resizing --- .../frame/data/columns/HashMapToInt.java | 119 +++++++++++++----- 1 file changed, 87 insertions(+), 32 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java index 7f377dd633a..e276f680947 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java @@ -34,6 +34,8 @@ public class HashMapToInt implements Map, Serializable, Cloneable static final float DEFAULT_LOAD_FACTOR = 0.75f; protected Node[] buckets; + + protected int nullV = -1; protected int size; public HashMapToInt(int capacity) { @@ -59,6 +61,8 @@ public boolean isEmpty() { @Override @SuppressWarnings({"unchecked"}) public boolean 
containsKey(Object key) { + if(key == null) + return nullV != -1; return getI((K) key) != -1; } @@ -85,22 +89,24 @@ public Integer get(Object key) { } public int getI(K key) { - final int ix = hash(key); - Node b = buckets[ix]; - final boolean keyNull = key == null; - if(b != null) { - do { - if((keyNull && b.key == null) || (b.key != null && b.key.equals(key))) - return b.value; + if(key == null) { + return nullV; + } + else { + final int ix = hash(key); + Node b = buckets[ix]; + if(b != null) { + do { + if(b.key.equals(key)) + return b.value; + } + while((b = b.next) != null); } - while((b = b.next) != null); + return -1; } - return -1; } public int hash(K key) { - if(key == null) - return 0; return Math.abs(key.hashCode()) % buckets.length; } @@ -123,23 +129,36 @@ public Integer putIfAbsent(K key, Integer value) { } public int putIfAbsentI(K key, int value) { - final int ix = hash(key); - Node b = buckets[ix]; - if(b == null) - return createBucket(ix, key, value); - else - return putIfAbsentBucket(ix, key, value); + + if(key == null) { + if(nullV == -1) { + size++; + nullV = value; + return -1; + } + else + return nullV; + } + else { + final int ix = hash(key); + Node b = buckets[ix]; + if(b == null) + return createBucket(ix, key, value); + else + return putIfAbsentBucket(ix, key, value); + } + } private int putIfAbsentBucket(int ix, K key, int value) { Node b = buckets[ix]; - final boolean keyNull = key == null; while(true) { - if((keyNull && b.key == null) || (b.key != null && b.key.equals(key))) + if(b.key.equals(key)) return b.value; if(b.next == null) { b.setNext(new Node<>(key, value, null)); size++; + resize(); return -1; } b = b.next; @@ -147,12 +166,21 @@ private int putIfAbsentBucket(int ix, K key, int value) { } public int putI(K key, int value) { - final int ix = hash(key); - Node b = buckets[ix]; - if(b == null) - return createBucket(ix, key, value); - else - return addToBucket(ix, key, value); + if(key == null) { + int tmp = nullV; + nullV = value; 
+ if(tmp != -1) + size++; + return tmp; + } + else { + final int ix = hash(key); + Node b = buckets[ix]; + if(b == null) + return createBucket(ix, key, value); + else + return addToBucket(ix, key, value); + } } private int createBucket(int ix, K key, int value) { @@ -163,9 +191,8 @@ private int createBucket(int ix, K key, int value) { private int addToBucket(int ix, K key, int value) { Node b = buckets[ix]; - final boolean keyNull = key == null; while(true) { - if((keyNull && b.key == null) || (b.key != null && b.key.equals(key))){ + if(b.key.equals(key)) { int tmp = b.getValue(); b.setValue(value); return tmp; @@ -173,12 +200,32 @@ private int addToBucket(int ix, K key, int value) { if(b.next == null) { b.setNext(new Node<>(key, value, null)); size++; + resize(); return -1; } b = b.next; } } + @SuppressWarnings({"unchecked"}) + private void resize() { + if(size > buckets.length * DEFAULT_LOAD_FACTOR) { + + Node[] tmp = (Node[]) new Node[buckets.length * 2]; + Node[] oldBuckets = buckets; + buckets = tmp; + + for(Node n : oldBuckets) { + if(n != null) + do { + put(n.key, n.value); + } + while((n = n.next) != null); + } + + } + } + @Override public Integer remove(Object key) { throw new UnsupportedOperationException("Unimplemented method 'remove'"); @@ -211,6 +258,8 @@ public Set> entrySet() { @Override public void forEach(BiConsumer action) { + if(nullV != -1) + action.accept(null, nullV); for(Node n : buckets) { if(n != null) { do { @@ -223,7 +272,7 @@ public void forEach(BiConsumer action) { @Override public String toString() { - StringBuilder sb = new StringBuilder(size()*3); + StringBuilder sb = new StringBuilder(size() * 3); this.forEach((k, v) -> { sb.append("(" + k + "→" + v + ")"); }); @@ -280,10 +329,16 @@ private final class EntryIterator implements Iterator> { int bucketId = 0; protected EntryIterator() { - for(; bucketId < buckets.length; bucketId++) { - if(buckets[bucketId] != null) { - next = buckets[bucketId]; - break; + + if(nullV != -1) { + next = 
new Node<>(null, nullV, null); + } + else { + for(; bucketId < buckets.length; bucketId++) { + if(buckets[bucketId] != null) { + next = buckets[bucketId]; + break; + } } } } From ebb355568357f5e020e04f2de004248a1f59a2dd Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Wed, 15 Jan 2025 02:22:53 +0100 Subject: [PATCH 70/81] fix cast to Integer --- .../java/org/apache/sysds/runtime/frame/data/columns/Array.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java index 9cdd3012e7e..21a9cb396ca 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/Array.java @@ -970,7 +970,7 @@ protected boolean earlyAbortEstimateDistinct(int distinctFound, int samplesTaken && distinctFound * 100 >= samplesTaken * 60; // More than 60 % distinct } - protected int setAndAddToDict(HashMapToInt rcd, AMapToData m, int i, Integer id) { + protected int setAndAddToDict(HashMapToInt rcd, AMapToData m, int i, int id) { final T val = getInternal(i); final int v = rcd.putIfAbsentI(val, id); if(v == -1) { From a2bc526fa2c5bc91ceb5a0a7831dff766e818a53 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Wed, 15 Jan 2025 02:36:20 +0100 Subject: [PATCH 71/81] inverse --- .../runtime/frame/data/columns/DDCArray.java | 12 +----------- .../frame/data/columns/HashMapToInt.java | 17 +++++++++++++++++ 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java index 5342e26f3c1..97e36d595bb 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java @@ -154,17 +154,7 @@ public static Array 
compressToDDC(Array arr, int estimateUnique) { final AMapToData md = m.resize(rcd.size()); // Allocate the correct dictionary output - final Array ar; - if(rcd.containsKey(null)) - ar = (Array) ArrayFactory.allocateOptional(arr.getValueType(), rcd.size()); - else - ar = (Array) ArrayFactory.allocate(arr.getValueType(), rcd.size()); - - // Set elements in the Dictionary array --- much smaller. - // This inverts the mapping such that the value - // is the index in the dictionary - for(Entry e : rcd.entrySet()) - ar.set(e.getValue(), e.getKey()); + final Array ar = rcd.inverse(arr.getValueType()); return new DDCArray<>(ar, md); } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java index e276f680947..7bcad1bf210 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java @@ -27,6 +27,8 @@ import java.util.Set; import java.util.function.BiConsumer; +import org.apache.sysds.common.Types.ValueType; + public class HashMapToInt implements Map, Serializable, Cloneable { private static final long serialVersionUID = 3624988207265L; @@ -270,6 +272,21 @@ public void forEach(BiConsumer action) { } } + @SuppressWarnings({"unchecked"}) + public Array inverse(ValueType t ) { + final Array ar; + + if(containsKey(null)) + ar = (Array) ArrayFactory.allocateOptional(t, size()); + else + ar = (Array) ArrayFactory.allocate(t, size()); + + forEach((k, v) -> { + ar.set(v, k); + }); + return ar; + } + @Override public String toString() { StringBuilder sb = new StringBuilder(size() * 3); From d7ee209946884bba3b95d67d80a674f2070c230f Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Wed, 15 Jan 2025 02:37:26 +0100 Subject: [PATCH 72/81] remove Integer.valueOf --- .../org/apache/sysds/runtime/frame/data/columns/DDCArray.java | 2 -- 1 file changed, 2 deletions(-) 
diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java index 97e36d595bb..6ffd2260917 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java @@ -22,7 +22,6 @@ import java.io.DataInput; import java.io.DataOutput; import java.io.IOException; -import java.util.Map.Entry; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; @@ -126,7 +125,6 @@ public static Array compressToDDC(Array arr) { * @param estimateUnique The estimated number of unique values * @return Either a compressed version or the original. */ - @SuppressWarnings("unchecked") public static Array compressToDDC(Array arr, int estimateUnique) { try { From b5fbce9d79407961bed7a78880af7ae7b22553d4 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Wed, 15 Jan 2025 02:41:02 +0100 Subject: [PATCH 73/81] ... that was stupid --- .../org/apache/sysds/runtime/frame/data/columns/DDCArray.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java index 6ffd2260917..0e5fee82f5b 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java @@ -140,7 +140,7 @@ public static Array compressToDDC(Array arr, int estimateUnique) { final HashMapToInt rcd = new HashMapToInt(estimateUnique == Integer.MAX_VALUE ? 16 : estimateUnique); // map should guarantee to be able to hold the distinct values. 
final AMapToData m = MapToFactory.create(s, Math.min(t, estimateUnique)); - Integer id = 0; + int id = 0; for(int i = 0; i < s && id < t; i++) id = arr.setAndAddToDict(rcd, m, i, id); From 0b45f8df17812c8bd6538ba02e31d92d0275ae9a Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Wed, 15 Jan 2025 13:54:19 +0100 Subject: [PATCH 74/81] fix resize to correctly reflect size --- .../sysds/runtime/frame/data/columns/HashMapToInt.java | 2 +- .../sysds/runtime/transform/encode/CompressedEncode.java | 8 +++++--- .../frame/transform/TransformCompressedTestSingleCol.java | 1 + 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java index 7bcad1bf210..9641010c20c 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java @@ -216,7 +216,7 @@ private void resize() { Node[] tmp = (Node[]) new Node[buckets.length * 2]; Node[] oldBuckets = buckets; buckets = tmp; - + size = nullV == -1 ? 0 : 1; for(Node n : oldBuckets) { if(n != null) do { diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java index c2733812e9a..33647345eb8 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java @@ -91,7 +91,8 @@ private CompressedEncode(MultiColumnEncoder enc, FrameBlock in, int k) { this.enc = enc; this.in = in; this.k = k; - this.pool = k > 1 && CommonThreadPool.useParallelismOnThread() ? CommonThreadPool.get(k) : null; + // this.pool = k > 1 && CommonThreadPool.useParallelismOnThread() ? 
CommonThreadPool.get(k) : null; + this.pool = null; this.inputContainsCompressed = containsCompressed(in); } @@ -203,6 +204,7 @@ private int shiftGroups(List groups) { private AColGroup encode(ColumnEncoderComposite c) throws Exception { final Timing t = new Timing(); AColGroup g = executeEncode(c); + if(LOG.isDebugEnabled()) LOG.debug(String.format("Encode: columns: %4d estimateDistinct: %6d distinct: %6d size: %6d time: %10f", c._colID, c._estNumDistincts, g.getNumValues(), g.estimateInMemorySize(), t.stop())); @@ -240,7 +242,7 @@ private AColGroup recodeToDummy(ColumnEncoderComposite c) throws Exception { r.set(0, new ColumnEncoderRecode(colId, (HashMapToInt) map)); int domain = map.size(); if(containsNull && domain == 0) - return new ColGroupEmpty(ColIndexFactory.create(1)); + return new ColGroupEmpty(SINGLE_COL_TMP_INDEX); IColIndex colIndexes = ColIndexFactory.create(0, domain); if(domain == 1 && !containsNull) { nnz.addAndGet(in.getNumRows()); @@ -761,7 +763,7 @@ private final long putInto(List ucg, DenseBlock db, i private final long putIntoRowBlock(List ucg, int jl, int ju, long nnz, int i, final double[] rval, final int off) { for(int j = jl; j < ju; j++) { - nnz += (rval[off + j] = ucg.get(j).array.getAsNaNDouble(i)) == 0.0 ? 1 : 0; + nnz += (rval[off + j] = ucg.get(j).array.getAsNaNDouble(i)) == 0.0 ? 
0 : 1; } return nnz; } diff --git a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestSingleCol.java b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestSingleCol.java index d5b8a094154..e58e724849e 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestSingleCol.java +++ b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestSingleCol.java @@ -175,6 +175,7 @@ public void test(String spec) { MatrixBlock outMeta1 = ec.apply(data, k); + TestUtils.compareMatrices(outNormal, outMeta1, 0, "Not Equal after apply"); MultiColumnEncoder ec2 = EncoderFactory.createEncoder(spec, data.getColumnNames(), data.getNumColumns(), From ea6ad186235fb9d418902a9bc9a9bfe5ad191147 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Wed, 15 Jan 2025 13:54:49 +0100 Subject: [PATCH 75/81] remove imports --- .../sysds/runtime/transform/encode/CompressedEncode.java | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java index 33647345eb8..f14d886ac16 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java @@ -61,7 +61,6 @@ import org.apache.sysds.runtime.frame.data.compress.ArrayCompressionStatistics; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.transform.encode.ColumnEncoderBin.BinMethod; -import org.apache.sysds.runtime.util.CommonThreadPool; import org.apache.sysds.runtime.util.UtilFunctions; import org.apache.sysds.utils.stats.Timing; @@ -92,7 +91,7 @@ private CompressedEncode(MultiColumnEncoder enc, FrameBlock in, int k) { this.in = in; this.k = k; // this.pool = k > 1 && CommonThreadPool.useParallelismOnThread() 
? CommonThreadPool.get(k) : null; - this.pool = null; + this.pool = null; this.inputContainsCompressed = containsCompressed(in); } @@ -204,7 +203,6 @@ private int shiftGroups(List groups) { private AColGroup encode(ColumnEncoderComposite c) throws Exception { final Timing t = new Timing(); AColGroup g = executeEncode(c); - if(LOG.isDebugEnabled()) LOG.debug(String.format("Encode: columns: %4d estimateDistinct: %6d distinct: %6d size: %6d time: %10f", c._colID, c._estNumDistincts, g.getNumValues(), g.estimateInMemorySize(), t.stop())); From 114c1709f9f143b8a10033e4439a57bae16b8633 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Wed, 15 Jan 2025 15:37:25 +0100 Subject: [PATCH 76/81] fix column offsets --- .../sysds/runtime/frame/data/columns/HashMapToInt.java | 2 +- .../sysds/runtime/transform/encode/ColumnEncoderBin.java | 3 +++ .../sysds/runtime/transform/encode/CompressedEncode.java | 5 +++-- .../frame/transform/TransformCompressedTestMultiCol.java | 3 ++- 4 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java index 9641010c20c..fa6a86fce49 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/HashMapToInt.java @@ -216,7 +216,7 @@ private void resize() { Node[] tmp = (Node[]) new Node[buckets.length * 2]; Node[] oldBuckets = buckets; buckets = tmp; - size = nullV == -1 ? 0 : 1; + size = (nullV == -1) ? 
0 : 1; for(Node n : oldBuckets) { if(n != null) do { diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java index 524a745a467..41cba2b9250 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java @@ -356,6 +356,9 @@ public Callable getPartialMergeBuildTask(HashMap ret) { } public void computeBins(double min, double max) { + if(min == max){ + _numBin = 1; + } // ensure allocated internal transformation metadata if(_binMins == null || _binMaxs == null) { _binMins = new double[_numBin]; diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java index f14d886ac16..f88f737e6da 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java @@ -61,6 +61,7 @@ import org.apache.sysds.runtime.frame.data.compress.ArrayCompressionStatistics; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.transform.encode.ColumnEncoderBin.BinMethod; +import org.apache.sysds.runtime.util.CommonThreadPool; import org.apache.sysds.runtime.util.UtilFunctions; import org.apache.sysds.utils.stats.Timing; @@ -90,8 +91,7 @@ private CompressedEncode(MultiColumnEncoder enc, FrameBlock in, int k) { this.enc = enc; this.in = in; this.k = k; - // this.pool = k > 1 && CommonThreadPool.useParallelismOnThread() ? CommonThreadPool.get(k) : null; - this.pool = null; + this.pool = k > 1 && CommonThreadPool.useParallelismOnThread() ? 
CommonThreadPool.get(k) : null; this.inputContainsCompressed = containsCompressed(in); } @@ -190,6 +190,7 @@ private int shiftGroups(List groups) { AColGroup g = groups.get(curGroup); groups.set(curGroup, g.shiftColIndices(curCols)); curCols += g.getColIndices().size(); + curGroup++; } } if(ucCols.size() > 0) { diff --git a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestMultiCol.java b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestMultiCol.java index 8094a59f48b..2b51a77b705 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestMultiCol.java +++ b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestMultiCol.java @@ -64,7 +64,7 @@ public static Collection data() { final int[] threads = new int[] {1, 4}; try { - ValueType[] kPlusCols = new ValueType[1002]; + ValueType[] kPlusCols = new ValueType[100]; Arrays.fill(kPlusCols, ValueType.BOOLEAN); @@ -167,6 +167,7 @@ public void test(String spec) { data.getNumColumns(), meta); MatrixBlock outNormal = encoderNormal.encode(data, k); + TestUtils.compareMatrices(outNormal, outCompressed, 0, "Not Equal after apply"); MultiColumnEncoder ec2 = EncoderFactory.createEncoder(spec, data.getColumnNames(), data.getNumColumns(), From fed33ecc22c6fda470fd55638982a72cea033f1a Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Wed, 15 Jan 2025 15:54:57 +0100 Subject: [PATCH 77/81] correct character array memory --- .../sysds/runtime/frame/data/columns/CharArray.java | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/CharArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/CharArray.java index a4192c6440f..69f2486678e 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/CharArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/CharArray.java @@ -31,6 
+31,7 @@ import org.apache.sysds.runtime.frame.data.lib.FrameUtil; import org.apache.sysds.runtime.matrix.data.Pair; import org.apache.sysds.runtime.util.UtilFunctions; +import org.apache.sysds.utils.MemoryEstimates; public class CharArray extends Array { @@ -372,6 +373,17 @@ public boolean possiblyContainsNaN() { return false; } + @Override + public long getInMemorySize() { + return estimateInMemorySize(_size); + } + + public static long estimateInMemorySize(int nRow) { + long size = baseMemoryCost(); // object header + object reference + size += MemoryEstimates.charArrayCost(nRow); + return size; + } + @Override public String toString() { StringBuilder sb = new StringBuilder(_size * 2 + 15); From f2756c6d62eeaf533e95cc893b3817cfc1601250 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Wed, 15 Jan 2025 16:16:30 +0100 Subject: [PATCH 78/81] transform full perf --- .../sysds/performance/frame/Transform.java | 119 +++++++++--------- 1 file changed, 60 insertions(+), 59 deletions(-) diff --git a/src/test/java/org/apache/sysds/performance/frame/Transform.java b/src/test/java/org/apache/sysds/performance/frame/Transform.java index 27d4acad255..c63d31eea92 100644 --- a/src/test/java/org/apache/sysds/performance/frame/Transform.java +++ b/src/test/java/org/apache/sysds/performance/frame/Transform.java @@ -27,6 +27,7 @@ import org.apache.sysds.performance.generators.IGenerate; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.frame.data.columns.Array; +import org.apache.sysds.runtime.frame.data.lib.FrameLibCompress; import org.apache.sysds.runtime.transform.encode.EncoderFactory; import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder; import org.apache.sysds.test.TestUtils; @@ -48,7 +49,7 @@ public Transform(int N, IGenerate gen, int k, String spec) { } public void run() throws Exception { - // execute(() -> te(), () -> clear(), "Normal"); + execute(() -> te(), () -> clear(), "Normal"); execute(() -> tec(), () -> 
clear(), "Compressed"); } @@ -88,46 +89,46 @@ public static void main(String[] args) throws Exception { FrameBlock in; for(int i = 1; i < 1000; i *= 10) { + int rows = 100000 * i; + in = TestUtils.generateRandomFrameBlock(rows, new ValueType[] {ValueType.UINT4}, 32); - // in = TestUtils.generateRandomFrameBlock(100000 * i, new ValueType[] {ValueType.UINT4}, 32); + System.out.println("Without null"); + run(k, in); - // System.out.println("Without null"); - // run(k, in); + System.out.println("Compressed without null"); + in = FrameLibCompress.compress(in, k); + run(k, in); - // System.out.println("Compressed without null"); - // in = FrameLibCompress.compress(in, k); - // run(k, in); + in = TestUtils.generateRandomFrameBlock(rows, new ValueType[] {ValueType.UINT4}, 32, 0.5); - // in = TestUtils.generateRandomFrameBlock(100000 * i, new ValueType[] {ValueType.UINT4}, 32, 0.5); + System.out.println("With null"); - // System.out.println("With null"); - - // run(k, in); - // System.out.println("Compressed with null"); - // in = FrameLibCompress.compress(in, k); - // run(k, in); + run(k, in); + System.out.println("Compressed with null"); + in = FrameLibCompress.compress(in, k); + run(k, in); in = TestUtils.generateRandomFrameBlock( - 100000 * i, new ValueType[] {ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, + rows, new ValueType[] {ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, ValueType.UINT4}, 32); System.out.println("10 col without null"); run10(k, in); - // System.out.println("10 col compressed without null"); - // in = FrameLibCompress.compress(in, k); - // run10(k, in); + System.out.println("10 col compressed without null"); + in = FrameLibCompress.compress(in, k); + run10(k, in); in = TestUtils.generateRandomFrameBlock( - 100000 * i, new ValueType[] {ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, + rows, new 
ValueType[] {ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, ValueType.UINT4, ValueType.UINT4}, 32, 0.5); - // System.out.println("10 col with null"); - // run10(k, in); - // System.out.println("10 col Compressed with null"); - // in = FrameLibCompress.compress(in, k); - // run10(k, in); + System.out.println("10 col with null"); + run10(k, in); + System.out.println("10 col Compressed with null"); + in = FrameLibCompress.compress(in, k); + run10(k, in); } System.exit(0); // forcefully stop. @@ -135,46 +136,46 @@ public static void main(String[] args) throws Exception { private static void run10(int k, FrameBlock in) throws Exception { ConstFrame gen = new ConstFrame(in); - new Transform(300, gen, k, "{}").run(); - // new Transform(300, gen, k, "{ids:true, recode:[1,2,3,4,5,6,7,8,9,10]}").run(); - // new Transform(300, gen, k, "{ids:true, bin:[" // - // + "\n{id:1, method:equi-width, numbins:4}," // - // + "\n{id:2, method:equi-width, numbins:4}," // - // + "\n{id:3, method:equi-width, numbins:4}," // - // + "\n{id:4, method:equi-width, numbins:4}," // - // + "\n{id:5, method:equi-width, numbins:4}," // - // + "\n{id:6, method:equi-width, numbins:4}," // - // + "\n{id:7, method:equi-width, numbins:4}," // - // + "\n{id:8, method:equi-width, numbins:4}," // - // + "\n{id:9, method:equi-width, numbins:4}," // - // + "\n{id:10, method:equi-width, numbins:4}," // - // + "]}").run(); - // new Transform(300, gen, k, "{ids:true, bin:[" // - // + "\n{id:1, method:equi-width, numbins:4}," // - // + "\n{id:2, method:equi-width, numbins:4}," // - // + "\n{id:3, method:equi-width, numbins:4}," // - // + "\n{id:4, method:equi-width, numbins:4}," // - // + "\n{id:5, method:equi-width, numbins:4}," // - // + "\n{id:6, method:equi-width, numbins:4}," // - // + "\n{id:7, method:equi-width, numbins:4}," // - // + "\n{id:8, method:equi-width, numbins:4}," // - // + "\n{id:9, method:equi-width, 
numbins:4}," // - // + "\n{id:10, method:equi-width, numbins:4}," // - // + "], dummycode:[1,2,3,4,5,6,7,8,9,10]}").run(); - // new Transform(300, gen, k, "{ids:true, hash:[1,2,3,4,5,6,7,8,9,10], K:10}").run(); - // new Transform(300, gen, k, "{ids:true, hash:[1,2,3,4,5,6,7,8,9,10], K:10, dummycode:[1,2,3,4,5,6,7,8,9,10]}") - // .run(); + new Transform(30, gen, k, "{}").run(); + new Transform(30, gen, k, "{ids:true, recode:[1,2,3,4,5,6,7,8,9,10]}").run(); + new Transform(30, gen, k, "{ids:true, bin:[" // + + "\n{id:1, method:equi-width, numbins:4}," // + + "\n{id:2, method:equi-width, numbins:4}," // + + "\n{id:3, method:equi-width, numbins:4}," // + + "\n{id:4, method:equi-width, numbins:4}," // + + "\n{id:5, method:equi-width, numbins:4}," // + + "\n{id:6, method:equi-width, numbins:4}," // + + "\n{id:7, method:equi-width, numbins:4}," // + + "\n{id:8, method:equi-width, numbins:4}," // + + "\n{id:9, method:equi-width, numbins:4}," // + + "\n{id:10, method:equi-width, numbins:4}," // + + "]}").run(); + new Transform(30, gen, k, "{ids:true, bin:[" // + + "\n{id:1, method:equi-width, numbins:4}," // + + "\n{id:2, method:equi-width, numbins:4}," // + + "\n{id:3, method:equi-width, numbins:4}," // + + "\n{id:4, method:equi-width, numbins:4}," // + + "\n{id:5, method:equi-width, numbins:4}," // + + "\n{id:6, method:equi-width, numbins:4}," // + + "\n{id:7, method:equi-width, numbins:4}," // + + "\n{id:8, method:equi-width, numbins:4}," // + + "\n{id:9, method:equi-width, numbins:4}," // + + "\n{id:10, method:equi-width, numbins:4}," // + + "], dummycode:[1,2,3,4,5,6,7,8,9,10]}").run(); + new Transform(30, gen, k, "{ids:true, hash:[1,2,3,4,5,6,7,8,9,10], K:10}").run(); + new Transform(30, gen, k, "{ids:true, hash:[1,2,3,4,5,6,7,8,9,10], K:10, dummycode:[1,2,3,4,5,6,7,8,9,10]}") + .run(); } private static void run(int k, FrameBlock in) throws Exception { ConstFrame gen = new ConstFrame(in); // // passthrough - // new Transform(300, gen, k, "{}").run(); - // new 
Transform(300, gen, k, "{ids:true, recode:[1]}").run(); - // new Transform(300, gen, k, "{ids:true, bin:[{id:1, method:equi-width, numbins:4}]}").run(); - // new Transform(300, gen, k, "{ids:true, bin:[{id:1, method:equi-width, numbins:4}], dummycode:[1]}").run(); - // new Transform(300, gen, k, "{ids:true, hash:[1], K:10}").run(); - // new Transform(300, gen, k, "{ids:true, hash:[1], K:10, dummycode:[1]}").run(); + new Transform(30, gen, k, "{}").run(); + new Transform(30, gen, k, "{ids:true, recode:[1]}").run(); + new Transform(30, gen, k, "{ids:true, bin:[{id:1, method:equi-width, numbins:4}]}").run(); + new Transform(30, gen, k, "{ids:true, bin:[{id:1, method:equi-width, numbins:4}], dummycode:[1]}").run(); + new Transform(30, gen, k, "{ids:true, hash:[1], K:10}").run(); + new Transform(30, gen, k, "{ids:true, hash:[1], K:10, dummycode:[1]}").run(); } } From 0221af009ccc98d9320fa113a939ff17338c4730 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Wed, 15 Jan 2025 16:29:11 +0100 Subject: [PATCH 79/81] reduce --- .../performance/compression/APerfTest.java | 14 +++++++- .../sysds/performance/frame/Transform.java | 32 +++++++++---------- 2 files changed, 29 insertions(+), 17 deletions(-) diff --git a/src/test/java/org/apache/sysds/performance/compression/APerfTest.java b/src/test/java/org/apache/sysds/performance/compression/APerfTest.java index 74114bf84e7..2c49393384f 100644 --- a/src/test/java/org/apache/sysds/performance/compression/APerfTest.java +++ b/src/test/java/org/apache/sysds/performance/compression/APerfTest.java @@ -36,10 +36,22 @@ public abstract class APerfTest { /** Default Repetitions */ protected final int N; + /** Warmup iterations */ + protected final int W; + protected APerfTest(int N, IGenerate gen) { ret = new ArrayList<>(N); this.gen = gen; this.N = N; + this.W = 10; + } + + + protected APerfTest(int N, int W, IGenerate gen) { + ret = new ArrayList<>(N); + this.gen = gen; + this.N = N; + this.W = 10; } protected void execute(F f, 
String name) throws InterruptedException { @@ -53,7 +65,7 @@ protected void execute(F f, F c, String name) throws InterruptedException { } protected void execute(F f, F c, F b, String name) throws InterruptedException { - warmup(f, 10); + warmup(f, W); gen.generate(N); ret.clear(); double[] times = TimingUtils.time(f, c, b, N, gen); diff --git a/src/test/java/org/apache/sysds/performance/frame/Transform.java b/src/test/java/org/apache/sysds/performance/frame/Transform.java index c63d31eea92..78740d92cba 100644 --- a/src/test/java/org/apache/sysds/performance/frame/Transform.java +++ b/src/test/java/org/apache/sysds/performance/frame/Transform.java @@ -39,7 +39,7 @@ public class Transform extends APerfTest { private final String spec; public Transform(int N, IGenerate gen, int k, String spec) { - super(N, gen); + super(N,2, gen); this.k = k; this.spec = spec; FrameBlock in = gen.take(); @@ -88,8 +88,8 @@ public static void main(String[] args) throws Exception { int k = InfrastructureAnalyzer.getLocalParallelism(); FrameBlock in; - for(int i = 1; i < 1000; i *= 10) { - int rows = 100000 * i; + // for(int i = 1; i < 1000; i *= 10) { + int rows = 100000 * 100; in = TestUtils.generateRandomFrameBlock(rows, new ValueType[] {ValueType.UINT4}, 32); System.out.println("Without null"); @@ -129,16 +129,16 @@ public static void main(String[] args) throws Exception { System.out.println("10 col Compressed with null"); in = FrameLibCompress.compress(in, k); run10(k, in); - } + // } System.exit(0); // forcefully stop. 
} private static void run10(int k, FrameBlock in) throws Exception { ConstFrame gen = new ConstFrame(in); - new Transform(30, gen, k, "{}").run(); - new Transform(30, gen, k, "{ids:true, recode:[1,2,3,4,5,6,7,8,9,10]}").run(); - new Transform(30, gen, k, "{ids:true, bin:[" // + new Transform(10, gen, k, "{}").run(); + new Transform(10, gen, k, "{ids:true, recode:[1,2,3,4,5,6,7,8,9,10]}").run(); + new Transform(10, gen, k, "{ids:true, bin:[" // + "\n{id:1, method:equi-width, numbins:4}," // + "\n{id:2, method:equi-width, numbins:4}," // + "\n{id:3, method:equi-width, numbins:4}," // @@ -150,7 +150,7 @@ private static void run10(int k, FrameBlock in) throws Exception { + "\n{id:9, method:equi-width, numbins:4}," // + "\n{id:10, method:equi-width, numbins:4}," // + "]}").run(); - new Transform(30, gen, k, "{ids:true, bin:[" // + new Transform(10, gen, k, "{ids:true, bin:[" // + "\n{id:1, method:equi-width, numbins:4}," // + "\n{id:2, method:equi-width, numbins:4}," // + "\n{id:3, method:equi-width, numbins:4}," // @@ -162,20 +162,20 @@ private static void run10(int k, FrameBlock in) throws Exception { + "\n{id:9, method:equi-width, numbins:4}," // + "\n{id:10, method:equi-width, numbins:4}," // + "], dummycode:[1,2,3,4,5,6,7,8,9,10]}").run(); - new Transform(30, gen, k, "{ids:true, hash:[1,2,3,4,5,6,7,8,9,10], K:10}").run(); - new Transform(30, gen, k, "{ids:true, hash:[1,2,3,4,5,6,7,8,9,10], K:10, dummycode:[1,2,3,4,5,6,7,8,9,10]}") + new Transform(10, gen, k, "{ids:true, hash:[1,2,3,4,5,6,7,8,9,10], K:10}").run(); + new Transform(10, gen, k, "{ids:true, hash:[1,2,3,4,5,6,7,8,9,10], K:10, dummycode:[1,2,3,4,5,6,7,8,9,10]}") .run(); } private static void run(int k, FrameBlock in) throws Exception { ConstFrame gen = new ConstFrame(in); // // passthrough - new Transform(30, gen, k, "{}").run(); - new Transform(30, gen, k, "{ids:true, recode:[1]}").run(); - new Transform(30, gen, k, "{ids:true, bin:[{id:1, method:equi-width, numbins:4}]}").run(); - new Transform(30, gen, 
k, "{ids:true, bin:[{id:1, method:equi-width, numbins:4}], dummycode:[1]}").run(); - new Transform(30, gen, k, "{ids:true, hash:[1], K:10}").run(); - new Transform(30, gen, k, "{ids:true, hash:[1], K:10, dummycode:[1]}").run(); + new Transform(10, gen, k, "{}").run(); + new Transform(10, gen, k, "{ids:true, recode:[1]}").run(); + new Transform(10, gen, k, "{ids:true, bin:[{id:1, method:equi-width, numbins:4}]}").run(); + new Transform(10, gen, k, "{ids:true, bin:[{id:1, method:equi-width, numbins:4}], dummycode:[1]}").run(); + new Transform(10, gen, k, "{ids:true, hash:[1], K:10}").run(); + new Transform(10, gen, k, "{ids:true, hash:[1], K:10, dummycode:[1]}").run(); } } From 3fe379a0bc623c871d5c539d79982c29d4a792d2 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Wed, 15 Jan 2025 17:01:54 +0100 Subject: [PATCH 80/81] repeat passthrough --- .../sysds/performance/frame/Transform.java | 68 +++++++++---------- 1 file changed, 34 insertions(+), 34 deletions(-) diff --git a/src/test/java/org/apache/sysds/performance/frame/Transform.java b/src/test/java/org/apache/sysds/performance/frame/Transform.java index 78740d92cba..9fd188ac950 100644 --- a/src/test/java/org/apache/sysds/performance/frame/Transform.java +++ b/src/test/java/org/apache/sysds/performance/frame/Transform.java @@ -136,46 +136,46 @@ public static void main(String[] args) throws Exception { private static void run10(int k, FrameBlock in) throws Exception { ConstFrame gen = new ConstFrame(in); - new Transform(10, gen, k, "{}").run(); - new Transform(10, gen, k, "{ids:true, recode:[1,2,3,4,5,6,7,8,9,10]}").run(); - new Transform(10, gen, k, "{ids:true, bin:[" // - + "\n{id:1, method:equi-width, numbins:4}," // - + "\n{id:2, method:equi-width, numbins:4}," // - + "\n{id:3, method:equi-width, numbins:4}," // - + "\n{id:4, method:equi-width, numbins:4}," // - + "\n{id:5, method:equi-width, numbins:4}," // - + "\n{id:6, method:equi-width, numbins:4}," // - + "\n{id:7, method:equi-width, numbins:4}," // 
- + "\n{id:8, method:equi-width, numbins:4}," // - + "\n{id:9, method:equi-width, numbins:4}," // - + "\n{id:10, method:equi-width, numbins:4}," // - + "]}").run(); - new Transform(10, gen, k, "{ids:true, bin:[" // - + "\n{id:1, method:equi-width, numbins:4}," // - + "\n{id:2, method:equi-width, numbins:4}," // - + "\n{id:3, method:equi-width, numbins:4}," // - + "\n{id:4, method:equi-width, numbins:4}," // - + "\n{id:5, method:equi-width, numbins:4}," // - + "\n{id:6, method:equi-width, numbins:4}," // - + "\n{id:7, method:equi-width, numbins:4}," // - + "\n{id:8, method:equi-width, numbins:4}," // - + "\n{id:9, method:equi-width, numbins:4}," // - + "\n{id:10, method:equi-width, numbins:4}," // - + "], dummycode:[1,2,3,4,5,6,7,8,9,10]}").run(); - new Transform(10, gen, k, "{ids:true, hash:[1,2,3,4,5,6,7,8,9,10], K:10}").run(); - new Transform(10, gen, k, "{ids:true, hash:[1,2,3,4,5,6,7,8,9,10], K:10, dummycode:[1,2,3,4,5,6,7,8,9,10]}") - .run(); + new Transform(20, gen, k, "{}").run(); + // new Transform(10, gen, k, "{ids:true, recode:[1,2,3,4,5,6,7,8,9,10]}").run(); + // new Transform(10, gen, k, "{ids:true, bin:[" // + // + "\n{id:1, method:equi-width, numbins:4}," // + // + "\n{id:2, method:equi-width, numbins:4}," // + // + "\n{id:3, method:equi-width, numbins:4}," // + // + "\n{id:4, method:equi-width, numbins:4}," // + // + "\n{id:5, method:equi-width, numbins:4}," // + // + "\n{id:6, method:equi-width, numbins:4}," // + // + "\n{id:7, method:equi-width, numbins:4}," // + // + "\n{id:8, method:equi-width, numbins:4}," // + // + "\n{id:9, method:equi-width, numbins:4}," // + // + "\n{id:10, method:equi-width, numbins:4}," // + // + "]}").run(); + // new Transform(10, gen, k, "{ids:true, bin:[" // + // + "\n{id:1, method:equi-width, numbins:4}," // + // + "\n{id:2, method:equi-width, numbins:4}," // + // + "\n{id:3, method:equi-width, numbins:4}," // + // + "\n{id:4, method:equi-width, numbins:4}," // + // + "\n{id:5, method:equi-width, numbins:4}," // + // + 
"\n{id:6, method:equi-width, numbins:4}," // + // + "\n{id:7, method:equi-width, numbins:4}," // + // + "\n{id:8, method:equi-width, numbins:4}," // + // + "\n{id:9, method:equi-width, numbins:4}," // + // + "\n{id:10, method:equi-width, numbins:4}," // + // + "], dummycode:[1,2,3,4,5,6,7,8,9,10]}").run(); + // new Transform(10, gen, k, "{ids:true, hash:[1,2,3,4,5,6,7,8,9,10], K:10}").run(); + // new Transform(10, gen, k, "{ids:true, hash:[1,2,3,4,5,6,7,8,9,10], K:10, dummycode:[1,2,3,4,5,6,7,8,9,10]}") + // .run(); } private static void run(int k, FrameBlock in) throws Exception { ConstFrame gen = new ConstFrame(in); // // passthrough new Transform(10, gen, k, "{}").run(); - new Transform(10, gen, k, "{ids:true, recode:[1]}").run(); - new Transform(10, gen, k, "{ids:true, bin:[{id:1, method:equi-width, numbins:4}]}").run(); - new Transform(10, gen, k, "{ids:true, bin:[{id:1, method:equi-width, numbins:4}], dummycode:[1]}").run(); - new Transform(10, gen, k, "{ids:true, hash:[1], K:10}").run(); - new Transform(10, gen, k, "{ids:true, hash:[1], K:10, dummycode:[1]}").run(); + // new Transform(10, gen, k, "{ids:true, recode:[1]}").run(); + // new Transform(10, gen, k, "{ids:true, bin:[{id:1, method:equi-width, numbins:4}]}").run(); + // new Transform(10, gen, k, "{ids:true, bin:[{id:1, method:equi-width, numbins:4}], dummycode:[1]}").run(); + // new Transform(10, gen, k, "{ids:true, hash:[1], K:10}").run(); + // new Transform(10, gen, k, "{ids:true, hash:[1], K:10, dummycode:[1]}").run(); } } From 3487c8835a28837ae1a1e694aac8bc070ad03d83 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Wed, 15 Jan 2025 17:15:25 +0100 Subject: [PATCH 81/81] transform revert to cover all --- .../sysds/performance/frame/Transform.java | 66 +++++++++---------- 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/src/test/java/org/apache/sysds/performance/frame/Transform.java b/src/test/java/org/apache/sysds/performance/frame/Transform.java index 
9fd188ac950..2825790bf90 100644 --- a/src/test/java/org/apache/sysds/performance/frame/Transform.java +++ b/src/test/java/org/apache/sysds/performance/frame/Transform.java @@ -137,45 +137,45 @@ public static void main(String[] args) throws Exception { private static void run10(int k, FrameBlock in) throws Exception { ConstFrame gen = new ConstFrame(in); new Transform(20, gen, k, "{}").run(); - // new Transform(10, gen, k, "{ids:true, recode:[1,2,3,4,5,6,7,8,9,10]}").run(); - // new Transform(10, gen, k, "{ids:true, bin:[" // - // + "\n{id:1, method:equi-width, numbins:4}," // - // + "\n{id:2, method:equi-width, numbins:4}," // - // + "\n{id:3, method:equi-width, numbins:4}," // - // + "\n{id:4, method:equi-width, numbins:4}," // - // + "\n{id:5, method:equi-width, numbins:4}," // - // + "\n{id:6, method:equi-width, numbins:4}," // - // + "\n{id:7, method:equi-width, numbins:4}," // - // + "\n{id:8, method:equi-width, numbins:4}," // - // + "\n{id:9, method:equi-width, numbins:4}," // - // + "\n{id:10, method:equi-width, numbins:4}," // - // + "]}").run(); - // new Transform(10, gen, k, "{ids:true, bin:[" // - // + "\n{id:1, method:equi-width, numbins:4}," // - // + "\n{id:2, method:equi-width, numbins:4}," // - // + "\n{id:3, method:equi-width, numbins:4}," // - // + "\n{id:4, method:equi-width, numbins:4}," // - // + "\n{id:5, method:equi-width, numbins:4}," // - // + "\n{id:6, method:equi-width, numbins:4}," // - // + "\n{id:7, method:equi-width, numbins:4}," // - // + "\n{id:8, method:equi-width, numbins:4}," // - // + "\n{id:9, method:equi-width, numbins:4}," // - // + "\n{id:10, method:equi-width, numbins:4}," // - // + "], dummycode:[1,2,3,4,5,6,7,8,9,10]}").run(); - // new Transform(10, gen, k, "{ids:true, hash:[1,2,3,4,5,6,7,8,9,10], K:10}").run(); - // new Transform(10, gen, k, "{ids:true, hash:[1,2,3,4,5,6,7,8,9,10], K:10, dummycode:[1,2,3,4,5,6,7,8,9,10]}") - // .run(); + new Transform(10, gen, k, "{ids:true, recode:[1,2,3,4,5,6,7,8,9,10]}").run(); + new 
Transform(10, gen, k, "{ids:true, bin:[" // + + "\n{id:1, method:equi-width, numbins:4}," // + + "\n{id:2, method:equi-width, numbins:4}," // + + "\n{id:3, method:equi-width, numbins:4}," // + + "\n{id:4, method:equi-width, numbins:4}," // + + "\n{id:5, method:equi-width, numbins:4}," // + + "\n{id:6, method:equi-width, numbins:4}," // + + "\n{id:7, method:equi-width, numbins:4}," // + + "\n{id:8, method:equi-width, numbins:4}," // + + "\n{id:9, method:equi-width, numbins:4}," // + + "\n{id:10, method:equi-width, numbins:4}," // + + "]}").run(); + new Transform(10, gen, k, "{ids:true, bin:[" // + + "\n{id:1, method:equi-width, numbins:4}," // + + "\n{id:2, method:equi-width, numbins:4}," // + + "\n{id:3, method:equi-width, numbins:4}," // + + "\n{id:4, method:equi-width, numbins:4}," // + + "\n{id:5, method:equi-width, numbins:4}," // + + "\n{id:6, method:equi-width, numbins:4}," // + + "\n{id:7, method:equi-width, numbins:4}," // + + "\n{id:8, method:equi-width, numbins:4}," // + + "\n{id:9, method:equi-width, numbins:4}," // + + "\n{id:10, method:equi-width, numbins:4}," // + + "], dummycode:[1,2,3,4,5,6,7,8,9,10]}").run(); + new Transform(10, gen, k, "{ids:true, hash:[1,2,3,4,5,6,7,8,9,10], K:10}").run(); + new Transform(10, gen, k, "{ids:true, hash:[1,2,3,4,5,6,7,8,9,10], K:10, dummycode:[1,2,3,4,5,6,7,8,9,10]}") + .run(); } private static void run(int k, FrameBlock in) throws Exception { ConstFrame gen = new ConstFrame(in); // // passthrough new Transform(10, gen, k, "{}").run(); - // new Transform(10, gen, k, "{ids:true, recode:[1]}").run(); - // new Transform(10, gen, k, "{ids:true, bin:[{id:1, method:equi-width, numbins:4}]}").run(); - // new Transform(10, gen, k, "{ids:true, bin:[{id:1, method:equi-width, numbins:4}], dummycode:[1]}").run(); - // new Transform(10, gen, k, "{ids:true, hash:[1], K:10}").run(); - // new Transform(10, gen, k, "{ids:true, hash:[1], K:10, dummycode:[1]}").run(); + new Transform(10, gen, k, "{ids:true, recode:[1]}").run(); + new 
Transform(10, gen, k, "{ids:true, bin:[{id:1, method:equi-width, numbins:4}]}").run(); + new Transform(10, gen, k, "{ids:true, bin:[{id:1, method:equi-width, numbins:4}], dummycode:[1]}").run(); + new Transform(10, gen, k, "{ids:true, hash:[1], K:10}").run(); + new Transform(10, gen, k, "{ids:true, hash:[1], K:10, dummycode:[1]}").run(); } }