diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java index 1fc582924e4..8156a98cd35 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java @@ -607,17 +607,25 @@ public double getAsNaNDouble(int i) { private static double getAsDouble(String s) { try { - return DoubleArray.parseDouble(s); } catch(Exception e) { - String ls = s.toLowerCase(); - if(ls.equals("true") || ls.equals("t")) + // Fallback for boolean-like tokens. Dispatch on length first so non-boolean strings are + // rejected immediately, and avoid allocating a lower-cased copy by comparing case-insensitively + // (single char compare for the 1-char tokens). + final int len = s.length(); + if(len == 1) { + final char c = s.charAt(0); + if(c == 't' || c == 'T') + return 1; + else if(c == 'f' || c == 'F') + return 0; + } + else if(len == 4 && s.compareToIgnoreCase("true") == 0) return 1; - else if(ls.equals("false") || ls.equals("f")) + else if(len == 5 && s.compareToIgnoreCase("false") == 0) return 0; - else - throw new DMLRuntimeException("Unable to change to double: " + s, e); + throw new DMLRuntimeException("Unable to change to double: " + s, e); } } diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/Decoder.java b/src/main/java/org/apache/sysds/runtime/transform/decode/Decoder.java index 724af1be630..70834675ded 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/Decoder.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/Decoder.java @@ -23,6 +23,10 @@ import java.io.IOException; import java.io.ObjectInput; import java.io.ObjectOutput; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -30,6 +34,7 @@ import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.util.CommonThreadPool; /** * Base class for all transform decoders providing both a row and block @@ -77,8 +82,31 @@ public String[] getColnames() { * @param k Parallelization degree * @return returns the given output frame block for convenience */ - public FrameBlock decode(MatrixBlock in, FrameBlock out, int k) { - return decode(in, out); + public FrameBlock decode(final MatrixBlock in, final FrameBlock out, final int k) { + if(k <= 1) + return decode(in, out); + final ExecutorService pool = CommonThreadPool.get(k); + out.ensureAllocatedColumns(in.getNumRows()); + try { + final List> tasks = new ArrayList<>(); + int blz = Math.max((in.getNumRows() + k) / k, 1000); + + for(int i = 0; i < in.getNumRows(); i += blz){ + final int start = i; + final int end = Math.min(in.getNumRows(), i + blz); + tasks.add(pool.submit(() -> decode(in, out, start, end))); + } + + for(Future f : tasks) + f.get(); + return out; + } + catch(Exception e) { + throw new RuntimeException(e); + } + finally { + pool.shutdown(); + } } /** diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java index edee095f612..79d9b7f3a40 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java @@ -28,6 +28,7 @@ import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.frame.data.columns.Array; +import org.apache.sysds.runtime.frame.data.columns.ColumnMetadata; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.util.UtilFunctions; @@ -43,6 +44,8 @@ public class DecoderBin extends Decoder { // a) column bin boundaries private int[] _numBins; + private int[] _dcCols = null; + private int[] _srcCols = null; private double[][] _binMins = null; private double[][] _binMaxs = null; @@ -50,8 +53,9 @@ public DecoderBin() { super(null, null); } - protected DecoderBin(ValueType[] schema, int[] binCols) { + protected DecoderBin(ValueType[] schema, int[] binCols, int[] dcCols) { super(schema, binCols); + _dcCols = dcCols; } @Override @@ -66,14 +70,28 @@ public void decode(MatrixBlock in, FrameBlock out, int rl, int ru) { for( int i=rl; i< ru; i++ ) { for( int j=0; j<_colList.length; j++ ) { final Array a = out.getColumn(_colList[j] - 1); - final double val = in.get(i, _colList[j] - 1); + final double val = in.get(i, _srcCols[j] - 1); if(!Double.isNaN(val)){ - final int key = (int) Math.round(val); - double bmin = _binMins[j][key - 1]; - double bmax = _binMaxs[j][key - 1]; - double oval = bmin + (bmax - bmin) / 2 // bin center - + (val - key) * (bmax - bmin); // bin fractions - a.set(i, oval); + try{ + + final int key = (int) Math.round(val); + if(key == 0){ + a.set(i, _binMins[j][key]); + } + else{ + double bmin = _binMins[j][key - 1]; + double bmax = _binMaxs[j][key - 1]; + double oval = bmin + (bmax - bmin) / 2 // bin center + + (val - key) * (bmax - bmin); // bin fractions + a.set(i, oval); + } + } + catch(Exception e){ + LOG.error(a); + LOG.error(in.slice(0, in.getNumRows()-1, _colList[j]-1,_colList[j]-1)); + LOG.error( val); + throw e; + } } else a.set(i, val); // NaN @@ -111,6 +129,34 @@ public void initMetaData(FrameBlock meta) { _binMaxs[j][i] = Double.parseDouble(parts[1]); } } + + + if( _dcCols.length > 0 ) { + //prepare source column id mapping w/ dummy coding + _srcCols = new int[_colList.length]; + int ix1 = 0, ix2 = 0, off = 0; + while( ix1<_colList.length ) { + if( ix2>=_dcCols.length || _colList[ix1] < _dcCols[ix2] ) { + _srcCols[ix1] = _colList[ix1] + off; + ix1 ++; + } + else { //_colList[ix1] > _dcCols[ix2] + ColumnMetadata d =meta.getColumnMetadata()[_dcCols[ix2]-1]; + String v = meta.getString(0, _dcCols[ix2]-1); + if(v.length() > 1 && v.charAt(0) == '¿'){ + off += UtilFunctions.parseToLong(v.substring(1)) -1; + } + else { + off += d.isDefault() ? -1 : d.getNumDistinct() - 1; + } + ix2 ++; + } + } + } + else { + //prepare direct source column mapping + _srcCols = _colList; + } } @Override diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java index f4bc9f8b216..adfef7bbc6d 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java @@ -25,13 +25,10 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Future; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; -import org.apache.sysds.runtime.util.CommonThreadPool; /** * Simple composite decoder that applies a list of decoders @@ -59,33 +56,6 @@ public FrameBlock decode(MatrixBlock in, FrameBlock out) { return out; } - - @Override - public FrameBlock decode(final MatrixBlock in, final FrameBlock out, final int k) { - final ExecutorService pool = CommonThreadPool.get(k); - out.ensureAllocatedColumns(in.getNumRows()); - try { - final List> tasks = new ArrayList<>(); - int blz = Math.max(in.getNumRows() / k, 1000); - for(Decoder decoder : _decoders){ - for(int i = 0; i < in.getNumRows(); i += blz){ - final int start = i; - final int end = Math.min(in.getNumRows(), i + blz); - tasks.add(pool.submit(() -> decoder.decode(in, out, start, end))); - } - } - for(Future f : tasks) - f.get(); - return out; - } - catch(Exception e) { - throw new RuntimeException(e); - } - finally { - pool.shutdown(); - } - } - @Override public void decode(MatrixBlock in, FrameBlock out, int rl, int ru){ for( Decoder decoder : _decoders ) diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java index 0c4c6b42690..95d7f4fa4c9 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java @@ -27,31 +27,30 @@ import java.util.List; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.frame.data.columns.ColumnMetadata; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.util.UtilFunctions; /** - * Simple atomic decoder for dummycoded columns. This decoder builds internally - * inverted column mappings from the given frame meta data. - * + * Simple atomic decoder for dummycoded columns. This decoder builds internally inverted column mappings from the given + * frame meta data. + * */ -public class DecoderDummycode extends Decoder -{ +public class DecoderDummycode extends Decoder { private static final long serialVersionUID = 4758831042891032129L; - + private int[] _clPos = null; private int[] _cuPos = null; - + protected DecoderDummycode(ValueType[] schema, int[] dcCols) { - //dcCols refers to column IDs in output (non-dc) + // dcCols refers to column IDs in output (non-dc) super(schema, dcCols); } @Override public FrameBlock decode(MatrixBlock in, FrameBlock out) { - //TODO perf (exploit sparse representation for better asymptotic behavior) out.ensureAllocatedColumns(in.getNumRows()); decode(in, out, 0, in.getNumRows()); return out; @@ -59,59 +58,100 @@ public FrameBlock decode(MatrixBlock in, FrameBlock out) { @Override public void decode(MatrixBlock in, FrameBlock out, int rl, int ru) { - //TODO perf (exploit sparse representation for better asymptotic behavior) - // out.ensureAllocatedColumns(in.getNumRows()); - for( int i=rl; i= low && aix[h] < high) { + int k = aix[h]; + int col = _colList[j] - 1; + out.getColumn(col).set(i, k - low + 1); + } + // limit the binary search. + apos = h; + } + + } + @Override public Decoder subRangeDecoder(int colStart, int colEnd, int dummycodedOffset) { List dcList = new ArrayList<>(); List clPosList = new ArrayList<>(); List cuPosList = new ArrayList<>(); - + // get the column IDs for the sub range of the dummycode columns and their destination positions, // where they will be decoded to - for( int j=0; j<_colList.length; j++ ) { + for(int j = 0; j < _colList.length; j++) { int colID = _colList[j]; - if (colID >= colStart && colID < colEnd) { + if(colID >= colStart && colID < colEnd) { dcList.add(colID - (colStart - 1)); clPosList.add(_clPos[j] - dummycodedOffset); cuPosList.add(_cuPos[j] - dummycodedOffset); } } - if (dcList.isEmpty()) + if(dcList.isEmpty()) return null; // create sub-range decoder int[] colList = dcList.stream().mapToInt(i -> i).toArray(); - DecoderDummycode subRangeDecoder = new DecoderDummycode( - Arrays.copyOfRange(_schema, colStart - 1, colEnd - 1), colList); + DecoderDummycode subRangeDecoder = new DecoderDummycode(Arrays.copyOfRange(_schema, colStart - 1, colEnd - 1), + colList); subRangeDecoder._clPos = clPosList.stream().mapToInt(i -> i).toArray(); subRangeDecoder._cuPos = cuPosList.stream().mapToInt(i -> i).toArray(); return subRangeDecoder; } - + @Override public void updateIndexRanges(long[] beginDims, long[] endDims) { if(_colList == null) return; - + long lowerColDest = beginDims[1]; long upperColDest = endDims[1]; for(int i = 0; i < _colList.length; i++) { long numDistinct = _cuPos[i] - _clPos[i]; - + if(_cuPos[i] <= beginDims[1] + 1) if(numDistinct > 0) lowerColDest -= numDistinct - 1; - + if(_cuPos[i] <= endDims[1] + 1) if(numDistinct > 0) upperColDest -= numDistinct - 1; @@ -119,16 +159,25 @@ public void updateIndexRanges(long[] beginDims, long[] endDims) { beginDims[1] = lowerColDest; endDims[1] = upperColDest; } - + @Override public void initMetaData(FrameBlock meta) { - _clPos = new int[_colList.length]; //col lower pos - _cuPos = new int[_colList.length]; //col upper pos - for( int j=0, off=0; j<_colList.length; j++ ) { + _clPos = new int[_colList.length]; // col lower pos + _cuPos = new int[_colList.length]; // col upper pos + for(int j = 0, off = 0; j < _colList.length; j++) { int colID = _colList[j]; - ColumnMetadata d = meta.getColumnMetadata()[colID-1]; - int ndist = d.isDefault() ? 0 : (int)d.getNumDistinct(); - ndist = ndist < -1 ? 0: ndist; + ColumnMetadata d = meta.getColumnMetadata()[colID - 1]; + String v = meta.getString(0, colID - 1); + int ndist; + if(v.length() > 1 && v.charAt(0) == '¿') { + ndist = UtilFunctions.parseToInt(v.substring(1)); + } + else { + ndist = d.isDefault() ? 0 : (int) d.getNumDistinct(); + } + + ndist = ndist < -1 ? 0 : ndist; // safety if all values was null. + _clPos[j] = off + colID; _cuPos[j] = _clPos[j] + ndist; off += ndist - 1; diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java index 0a400e6da92..12ba2968877 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java @@ -64,34 +64,52 @@ public static Decoder createDecoder(String spec, String[] colnames, ValueType[] try { //parse transform specification JSONObject jSpec = new JSONObject(spec); - List ldecoders = new ArrayList<>(); - //create decoders 'bin', 'recode', 'dummy' and 'pass-through' + //create decoders 'bin', 'recode', 'hash', 'dummy', and 'pass-through' List binIDs = TfMetaUtils.parseBinningColIDs(jSpec, colnames, minCol, maxCol); List rcIDs = Arrays.asList(ArrayUtils.toObject( TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.RECODE.toString(), minCol, maxCol))); + List hcIDs = Arrays.asList(ArrayUtils.toObject( + TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.HASH.toString(), minCol, maxCol))); List dcIDs = Arrays.asList(ArrayUtils.toObject( TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.DUMMYCODE.toString(), minCol, maxCol))); + // only specially treat the columns with both recode and dictionary rcIDs = unionDistinct(rcIDs, dcIDs); + // remove hash recoded. // todo potentially wrong and remove? + rcIDs = except(rcIDs, hcIDs); + int len = dcIDs.isEmpty() ? Math.min(meta.getNumColumns(), clen) : meta.getNumColumns(); - List ptIDs = except(except(UtilFunctions.getSeqList(1, len, 1), rcIDs), binIDs); - + + // set the remaining columns to passthrough. + List ptIDs = UtilFunctions.getSeqList(1, len, 1); + // except recoded columns + ptIDs = except(ptIDs, rcIDs); + // binned columns + ptIDs = except(ptIDs, binIDs); + // hashed columns + ptIDs = except(ptIDs, hcIDs); // remove hashed columns + //create default schema if unspecified (with double columns for pass-through) if( schema == null ) { schema = UtilFunctions.nCopies(len, ValueType.STRING); for( Integer col : ptIDs ) schema[col-1] = ValueType.FP64; } + + // collect all the decoders in one list. + List ldecoders = new ArrayList<>(); if( !binIDs.isEmpty() ) { ldecoders.add(new DecoderBin(schema, - ArrayUtils.toPrimitive(binIDs.toArray(new Integer[0])))); + ArrayUtils.toPrimitive(binIDs.toArray(new Integer[0])), + ArrayUtils.toPrimitive(dcIDs.toArray(new Integer[0])))); } if( !dcIDs.isEmpty() ) { ldecoders.add(new DecoderDummycode(schema, ArrayUtils.toPrimitive(dcIDs.toArray(new Integer[0])))); } if( !rcIDs.isEmpty() ) { + // todo figure out if we need to handle rc columns with regards to dictionary offsets. ldecoders.add(new DecoderRecode(schema, !dcIDs.isEmpty(), ArrayUtils.toPrimitive(rcIDs.toArray(new Integer[0])))); } diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java index 5b6bf7a093e..d2e7d59e81f 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java @@ -61,13 +61,12 @@ public FrameBlock decode(MatrixBlock in, FrameBlock out) { @Override public void decode(MatrixBlock in, FrameBlock out, int rl, int ru) { int clen = Math.min(_colList.length, out.getNumColumns()); - for( int i=rl; i _dcCols[ix2] ColumnMetadata d =meta.getColumnMetadata()[_dcCols[ix2]-1]; - off += d.isDefault() ? -1 : d.getNumDistinct() - 1; + String v = meta.getString( 0,_dcCols[ix2]-1); + if(v.length() > 1 && v.charAt(0) == '¿'){ + off += UtilFunctions.parseToLong(v.substring(1)) -1; + } + else { + off += d.isDefault() ? -1 : d.getNumDistinct() - 1; + } ix2 ++; } } diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java index 33459a1c4f9..a48759493fa 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java @@ -29,6 +29,7 @@ import java.util.Map.Entry; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.data.Pair; @@ -46,7 +47,6 @@ public class DecoderRecode extends Decoder private static final long serialVersionUID = -3784249774608228805L; private HashMap[] _rcMaps = null; - private Object[][] _rcMapsDirect = null; private boolean _onOut = false; public DecoderRecode() { @@ -59,8 +59,7 @@ protected DecoderRecode(ValueType[] schema, boolean onOut, int[] rcCols) { } public Object getRcMapValue(int i, long key) { - return (_rcMapsDirect != null && key > 0) ? - _rcMapsDirect[i][(int)key-1] : _rcMaps[i].get(key); + return _rcMaps[i].get(key); } @Override @@ -129,27 +128,33 @@ public void initMetaData(FrameBlock meta) { for( int j=0; j<_colList.length; j++ ) { HashMap map = new HashMap<>(); for( int i=0; i v < Integer.MAX_VALUE) ) { - _rcMapsDirect = new Object[_rcMaps.length][]; - for( int i=0; i<_rcMaps.length; i++ ) { - Object[] arr = new Object[(int)max[i]]; - for(Entry e1 : _rcMaps[i].entrySet()) - arr[e1.getKey().intValue()-1] = e1.getValue(); - _rcMapsDirect[i] = arr; - } - } + // if( Arrays.stream(max).allMatch(v -> v < Integer.MAX_VALUE) ) { + // _rcMapsDirect = new Object[_rcMaps.length][]; + // for( int i=0; i<_rcMaps.length; i++ ) { + // Object[] arr = new Object[(int)max[i]]; + // for(Entry e1 : _rcMaps[i].entrySet()) + // arr[e1.getKey().intValue()-1] = e1.getValue(); + // _rcMapsDirect[i] = arr; + // } + // } } /** diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java index 400b7f64ffc..361c9c52135 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java @@ -146,7 +146,9 @@ public FrameBlock getMetaData(FrameBlock meta) { return meta; meta.ensureAllocatedColumns(1); - meta.set(0, _colID - 1, String.valueOf(_K)); + // set metadata of hash columns to magical hash value + k + meta.set(0, _colID - 1, String.format("¿%d" , _K)); + return meta; } @@ -154,7 +156,7 @@ public FrameBlock getMetaData(FrameBlock meta) { public void initMetaData(FrameBlock meta) { if(meta == null || meta.getNumRows() <= 0) return; - _K = UtilFunctions.parseToLong(meta.get(0, _colID - 1).toString()); + _K = UtilFunctions.parseToLong(meta.getString(0, _colID - 1).substring(1)); } @Override diff --git a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformDecodeRoundTripTest.java b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformDecodeRoundTripTest.java new file mode 100644 index 00000000000..9e8b55df29b --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformDecodeRoundTripTest.java @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.frame.transform; + +import static org.junit.Assert.fail; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.transform.decode.Decoder; +import org.apache.sysds.runtime.transform.decode.DecoderFactory; +import org.apache.sysds.runtime.transform.encode.EncoderFactory; +import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder; +import org.apache.sysds.test.TestUtils; +import org.junit.Before; +import org.junit.Test; + +/** + * Exact inverse correctness tests for the transform decoders. Recode and dummycode are lossless category encodings, so a + * decode of the encoded matrix must reconstruct the original categorical frame. These tests assert exact reconstruction + * for the dense path, the sparse path, and the parallel path so that the dummycode sparse binary search and the parallel + * block split are validated against ground truth rather than only against each other. + */ +public class TransformDecodeRoundTripTest { + protected static final Log LOG = LogFactory.getLog(TransformDecodeRoundTripTest.class.getName()); + + @Before + public void setUp() { + // name must contain "main" so the parallel decode path reuses the shared thread pool + Thread.currentThread().setName("main_test_decode"); + } + + private static FrameBlock categoricalFrame() { + final String[] values = new String[] { + "apple", "banana", "apple", "cherry", "banana", "date", "apple", "cherry", "date", "banana", "elderberry", + "apple", "fig", "banana", "cherry", "apple", "date", "fig", "elderberry", "banana"}; + final FrameBlock f = new FrameBlock(new ValueType[] {ValueType.STRING}); + f.ensureAllocatedColumns(values.length); + for(int i = 0; i < values.length; i++) + f.set(i, 0, values[i]); + return f; + } + + @Test + public void recodeReconstructsOriginalDense() { + roundTrip("{ids:true, recode:[1]}", false, 1); + } + + @Test + public void recodeReconstructsOriginalSparse() { + roundTrip("{ids:true, recode:[1]}", true, 1); + } + + @Test + public void recodeReconstructsOriginalParallel() { + roundTrip("{ids:true, recode:[1]}", false, 4); + } + + @Test + public void dummycodeReconstructsOriginalDense() { + roundTrip("{ids:true, recode:[1], dummycode:[1]}", false, 1); + } + + @Test + public void dummycodeReconstructsOriginalSparse() { + // the one-hot encoded matrix is sparse, so this drives the dummycode sparse binary-search decode path + roundTrip("{ids:true, recode:[1], dummycode:[1]}", true, 1); + } + + @Test + public void dummycodeReconstructsOriginalParallel() { + roundTrip("{ids:true, recode:[1], dummycode:[1]}", false, 4); + } + + /** + * Binning a column while a different column is dummycoded shifts the bin column's source position in the encoded + * matrix. The bin decoder must rebuild that source-column mapping from the dummycode domain sizes. This asserts the + * dense, sparse, and parallel decode paths agree for that layout (bin output is lossy, so exact reconstruction is + * not asserted, only cross-mode consistency and dimensions). + */ + @Test + public void binWithDummycodeOnOtherColumnConsistency() { + final String spec = "{ids:true, bin:[{id:1, method:equi-width, numbins:4}], dummycode:[2]}"; + try { + final FrameBlock original = TestUtils.generateRandomFrameBlock(150, + new ValueType[] {ValueType.FP32, ValueType.UINT4, ValueType.UINT8}, 4242); + final String[] colnames = original.getColumnNames(); + + final MultiColumnEncoder encoder = EncoderFactory.createEncoder(spec, colnames, original.getNumColumns(), + null); + final MatrixBlock encoded = encoder.encode(original, 1); + final FrameBlock meta = encoder.getMetaData(null); + + final MatrixBlock dense = new MatrixBlock(); + dense.copy(encoded); + if(dense.isInSparseFormat()) + dense.sparseToDense(); + + final MatrixBlock sparse = new MatrixBlock(); + sparse.copy(encoded); + if(!sparse.isInSparseFormat()) + sparse.denseToSparse(); + + final FrameBlock reference = decodeOnce(spec, colnames, meta, dense, 1); + final FrameBlock parallel = decodeOnce(spec, colnames, meta, dense, 4); + final FrameBlock fromSparse = decodeOnce(spec, colnames, meta, sparse, 1); + + org.junit.Assert.assertEquals(original.getNumRows(), reference.getNumRows()); + TestUtils.compareFrames(reference, parallel, false); + TestUtils.compareFrames(reference, fromSparse, false); + } + catch(Exception e) { + e.printStackTrace(); + fail(spec + " : " + e.getMessage()); + } + } + + private static FrameBlock decodeOnce(String spec, String[] colnames, FrameBlock meta, MatrixBlock in, int k) { + final Decoder decoder = DecoderFactory.createDecoder(spec, colnames, null, meta, in.getNumColumns()); + return decoder.decode(in, new FrameBlock(decoder.getSchema()), k); + } + + private void roundTrip(String spec, boolean sparse, int k) { + try { + final FrameBlock original = categoricalFrame(); + final String[] colnames = original.getColumnNames(); + + final MultiColumnEncoder encoder = EncoderFactory.createEncoder(spec, colnames, original.getNumColumns(), + null); + MatrixBlock encoded = encoder.encode(original, 1); + final FrameBlock meta = encoder.getMetaData(null); + + if(sparse && !encoded.isInSparseFormat()) + encoded.denseToSparse(); + else if(!sparse && encoded.isInSparseFormat()) + encoded.sparseToDense(); + + final Decoder decoder = DecoderFactory.createDecoder(spec, colnames, null, meta, encoded.getNumColumns()); + final FrameBlock decoded = decoder.decode(encoded, new FrameBlock(decoder.getSchema()), k); + + TestUtils.compareFrames(original, decoded, false); + } + catch(Exception e) { + e.printStackTrace(); + fail(spec + " (sparse=" + sparse + ", k=" + k + ") : " + e.getMessage()); + } + } +} diff --git a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformDecodeTest.java b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformDecodeTest.java new file mode 100644 index 00000000000..254937c20da --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformDecodeTest.java @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.frame.transform; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.logging.Level; +import java.util.logging.Logger; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.transform.decode.Decoder; +import org.apache.sysds.runtime.transform.decode.DecoderFactory; +import org.apache.sysds.runtime.transform.encode.EncoderFactory; +import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder; +import org.apache.sysds.runtime.util.CommonThreadPool; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameters; + +/** + * Component tests for the transform decoders. These exercise the row-block and parallel decode paths, the sparse and + * dense dummycode decode paths, the binning source-column offset mapping, and feature-hash column handling end-to-end + * through an encode followed by decode round trip. + */ +@RunWith(value = Parameterized.class) +public class TransformDecodeTest { + protected static final Log LOG = LogFactory.getLog(TransformDecodeTest.class.getName()); + + private final FrameBlock data; + private final int k; + + public TransformDecodeTest(FrameBlock data, int k) { + // name must contain "main" so the parallel decode path reuses the shared thread pool + Thread.currentThread().setName("main_test_decode"); + Logger.getLogger(CommonThreadPool.class.getName()).setLevel(Level.OFF); + this.data = data; + this.k = k; + } + + @Parameters + public static Collection data() { + final ArrayList tests = new ArrayList<>(); + final int[] threads = new int[] {1, 4}; + try { + final FrameBlock[] blocks = new FrameBlock[] { + // single low-cardinality categorical column + TestUtils.generateRandomFrameBlock(100, new ValueType[] {ValueType.UINT4}, 231), + // single categorical column with nulls + TestUtils.generateRandomFrameBlock(64, new ValueType[] {ValueType.UINT4}, 99, 0.2), + // multi column: dummycode/bin on col1 must offset the trailing passthrough columns + TestUtils.generateRandomFrameBlock(120, + new ValueType[] {ValueType.UINT4, ValueType.UINT8, ValueType.FP32}, 17), + // large enough to split into multiple row blocks in the parallel decode path + TestUtils.generateRandomFrameBlock(2500, new ValueType[] {ValueType.UINT4}, 7)}; + + for(FrameBlock block : blocks) + for(int k : threads) + tests.add(new Object[] {block, k}); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + return tests; + } + + @Test + public void testPassThrough() { + decodeConsistency("{ids:true}"); + } + + @Test + public void testRecode() { + decodeConsistency("{ids:true, recode:[1]}"); + } + + @Test + public void testDummycode() { + decodeConsistency("{ids:true, recode:[1], dummycode:[1]}"); + } + + @Test + public void testBinWidth() { + decodeConsistency("{ids:true, bin:[{id:1, method:equi-width, numbins:4}]}"); + } + + @Test + public void testBinHeight() { + decodeConsistency("{ids:true, bin:[{id:1, method:equi-height, numbins:10}]}"); + } + + @Test + public void testBinSingleBin() { + // numbins:1 forces the key==0 branch in the bin decoder + decodeConsistency("{ids:true, bin:[{id:1, method:equi-width, numbins:1}]}"); + } + + @Test + public void testHashToDummy() { + // feature-hash columns carry their domain size as the magic "¿K" metadata value, which the dummycode decoder + // must parse to reconstruct the one-hot column ranges + decodeConsistency("{ids:true, hash:[1], K:8, dummycode:[1]}"); + } + + @Test + public void testHashToDummyDomain1() { + decodeConsistency("{ids:true, hash:[1], K:1, dummycode:[1]}"); + } + + /** + * Encode the data, then decode the encoded matrix in three ways: serial dense, parallel dense, and serial sparse. + * All three must produce identical frames. This jointly exercises the parallel block-decode path in + * {@link Decoder#decode(MatrixBlock, FrameBlock, int)} and the separate sparse / dense dummycode decode paths. + */ + private void decodeConsistency(String spec) { + try { + final String[] colnames = data.getColumnNames(); + final MultiColumnEncoder encoder = EncoderFactory.createEncoder(spec, colnames, data.getNumColumns(), null); + final MatrixBlock encoded = encoder.encode(data, 1); + final FrameBlock meta = encoder.getMetaData(null); + + final MatrixBlock dense = forceDense(encoded); + final MatrixBlock sparse = forceSparse(encoded); + + final FrameBlock reference = decode(spec, colnames, meta, dense, 1); + final FrameBlock parallel = decode(spec, colnames, meta, dense, k); + final FrameBlock fromSparse = decode(spec, colnames, meta, sparse, 1); + + assertEquals("decoded rows must match input rows", data.getNumRows(), reference.getNumRows()); + + TestUtils.compareFrames(reference, parallel, false); + TestUtils.compareFrames(reference, fromSparse, false); + } + catch(Exception e) { + e.printStackTrace(); + fail(spec + " : " + e.getMessage()); + } + } + + private static FrameBlock decode(String spec, String[] colnames, FrameBlock meta, MatrixBlock in, int k) { + final Decoder decoder = DecoderFactory.createDecoder(spec, colnames, null, meta, in.getNumColumns()); + return decoder.decode(in, new FrameBlock(decoder.getSchema()), k); + } + + private static MatrixBlock forceDense(MatrixBlock in) { + final MatrixBlock out = new MatrixBlock(); + out.copy(in); + if(out.isInSparseFormat()) + out.sparseToDense(); + return out; + } + + private static MatrixBlock forceSparse(MatrixBlock in) { + final MatrixBlock out = new MatrixBlock(); + out.copy(in); + if(!out.isInSparseFormat()) + out.denseToSparse(); + return out; + } +}