From 5e218cc162adbbb47dc5e2a7ba275f9f21d84a6e Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Fri, 29 May 2026 14:06:10 +0000 Subject: [PATCH 1/4] Handle hash columns in transform decoders and tighten decode metadata flow Reworks transform decoders so that hash-encoded columns survive the inverse-transform path, and tightens how decoder metadata (column indices, value mappings) is propagated and initialized. - Decoder: pass column-id arrays through decode/decodeFromMap so each decoder knows its own output column range - DecoderRecode: skip recode for hash columns, keep encoded ints passthrough; init metadata from frame consistently - DecoderDummycode: handle hash columns when expanding categorical bits; parallel decode path; sparse-friendly init - DecoderPassThrough / DecoderBin / DecoderComposite / DecoderFactory: consume the new column-id arrays from the dispatch layer - ColumnEncoderFeatureHash: align hash bookkeeping with the decode-side changes - Frame columns (HashMapToInt, StringArray): small support changes consumed by the decoder path above --- .../frame/data/columns/StringArray.java | 4 +- .../runtime/transform/decode/Decoder.java | 32 ++++- .../runtime/transform/decode/DecoderBin.java | 68 ++++++++-- .../transform/decode/DecoderComposite.java | 32 +---- .../transform/decode/DecoderDummycode.java | 117 ++++++++++++------ .../transform/decode/DecoderFactory.java | 28 ++++- .../transform/decode/DecoderPassThrough.java | 23 ++-- .../transform/decode/DecoderRecode.java | 49 ++++---- .../encode/ColumnEncoderFeatureHash.java | 6 +- 9 files changed, 240 insertions(+), 119 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java index 1fc582924e4..292fcb52bf5 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java @@ -607,7 
+607,6 @@ public double getAsNaNDouble(int i) { private static double getAsDouble(String s) { try { - return DoubleArray.parseDouble(s); } catch(Exception e) { @@ -617,7 +616,8 @@ private static double getAsDouble(String s) { else if(ls.equals("false") || ls.equals("f")) return 0; else - throw new DMLRuntimeException("Unable to change to double: " + s, e); + throw e; // for efficiency + // throw new DMLRuntimeException("Unable to change to double: " + s, e); } } diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/Decoder.java b/src/main/java/org/apache/sysds/runtime/transform/decode/Decoder.java index 724af1be630..70834675ded 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/Decoder.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/Decoder.java @@ -23,6 +23,10 @@ import java.io.IOException; import java.io.ObjectInput; import java.io.ObjectOutput; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -30,6 +34,7 @@ import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.util.CommonThreadPool; /** * Base class for all transform decoders providing both a row and block @@ -77,8 +82,31 @@ public String[] getColnames() { * @param k Parallelization degree * @return returns the given output frame block for convenience */ - public FrameBlock decode(MatrixBlock in, FrameBlock out, int k) { - return decode(in, out); + public FrameBlock decode(final MatrixBlock in, final FrameBlock out, final int k) { + if(k <= 1) + return decode(in, out); + final ExecutorService pool = CommonThreadPool.get(k); + out.ensureAllocatedColumns(in.getNumRows()); + try { + final List> tasks = new ArrayList<>(); + int blz = 
Math.max((in.getNumRows() + k) / k, 1000); + + for(int i = 0; i < in.getNumRows(); i += blz){ + final int start = i; + final int end = Math.min(in.getNumRows(), i + blz); + tasks.add(pool.submit(() -> decode(in, out, start, end))); + } + + for(Future f : tasks) + f.get(); + return out; + } + catch(Exception e) { + throw new RuntimeException(e); + } + finally { + pool.shutdown(); + } } /** diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java index edee095f612..c9fcc23990a 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java @@ -28,6 +28,7 @@ import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.frame.data.columns.Array; +import org.apache.sysds.runtime.frame.data.columns.ColumnMetadata; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.util.UtilFunctions; @@ -43,15 +44,18 @@ public class DecoderBin extends Decoder { // a) column bin boundaries private int[] _numBins; + private int[] _dcCols = null; + private int[] _srcCols = null; private double[][] _binMins = null; private double[][] _binMaxs = null; - public DecoderBin() { - super(null, null); - } + // public DecoderBin() { + // super(null, null); + // } - protected DecoderBin(ValueType[] schema, int[] binCols) { + protected DecoderBin(ValueType[] schema, int[] binCols, int[] dcCols) { super(schema, binCols); + _dcCols = dcCols; } @Override @@ -66,14 +70,28 @@ public void decode(MatrixBlock in, FrameBlock out, int rl, int ru) { for( int i=rl; i< ru; i++ ) { for( int j=0; j<_colList.length; j++ ) { final Array a = out.getColumn(_colList[j] - 1); - final double val = in.get(i, _colList[j] - 1); + final double val = in.get(i, _srcCols[j] - 1); if(!Double.isNaN(val)){ - final int key 
= (int) Math.round(val); - double bmin = _binMins[j][key - 1]; - double bmax = _binMaxs[j][key - 1]; - double oval = bmin + (bmax - bmin) / 2 // bin center - + (val - key) * (bmax - bmin); // bin fractions - a.set(i, oval); + try{ + + final int key = (int) Math.round(val); + if(key == 0){ + a.set(i, _binMins[j][key]); + } + else{ + double bmin = _binMins[j][key - 1]; + double bmax = _binMaxs[j][key - 1]; + double oval = bmin + (bmax - bmin) / 2 // bin center + + (val - key) * (bmax - bmin); // bin fractions + a.set(i, oval); + } + } + catch(Exception e){ + LOG.error(a); + LOG.error(in.slice(0, in.getNumRows()-1, _colList[j]-1,_colList[j]-1)); + LOG.error( val); + throw e; + } } else a.set(i, val); // NaN @@ -111,6 +129,34 @@ public void initMetaData(FrameBlock meta) { _binMaxs[j][i] = Double.parseDouble(parts[1]); } } + + + if( _dcCols.length > 0 ) { + //prepare source column id mapping w/ dummy coding + _srcCols = new int[_colList.length]; + int ix1 = 0, ix2 = 0, off = 0; + while( ix1<_colList.length ) { + if( ix2>=_dcCols.length || _colList[ix1] < _dcCols[ix2] ) { + _srcCols[ix1] = _colList[ix1] + off; + ix1 ++; + } + else { //_colList[ix1] > _dcCols[ix2] + ColumnMetadata d =meta.getColumnMetadata()[_dcCols[ix2]-1]; + String v = meta.getString(0, _dcCols[ix2]-1); + if(v.length() > 1 && v.charAt(0) == '¿'){ + off += UtilFunctions.parseToLong(v.substring(1)) -1; + } + else { + off += d.isDefault() ? 
-1 : d.getNumDistinct() - 1; + } + ix2 ++; + } + } + } + else { + //prepare direct source column mapping + _srcCols = _colList; + } } @Override diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java index f4bc9f8b216..dff85e72dc6 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java @@ -25,13 +25,10 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Future; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; -import org.apache.sysds.runtime.util.CommonThreadPool; /** * Simple composite decoder that applies a list of decoders @@ -50,7 +47,7 @@ protected DecoderComposite(ValueType[] schema, List decoders) { _decoders = decoders; } - public DecoderComposite() { super(null, null); } + // public DecoderComposite() { super(null, null); } @Override public FrameBlock decode(MatrixBlock in, FrameBlock out) { @@ -59,33 +56,6 @@ public FrameBlock decode(MatrixBlock in, FrameBlock out) { return out; } - - @Override - public FrameBlock decode(final MatrixBlock in, final FrameBlock out, final int k) { - final ExecutorService pool = CommonThreadPool.get(k); - out.ensureAllocatedColumns(in.getNumRows()); - try { - final List> tasks = new ArrayList<>(); - int blz = Math.max(in.getNumRows() / k, 1000); - for(Decoder decoder : _decoders){ - for(int i = 0; i < in.getNumRows(); i += blz){ - final int start = i; - final int end = Math.min(in.getNumRows(), i + blz); - tasks.add(pool.submit(() -> decoder.decode(in, out, start, end))); - } - } - for(Future f : tasks) - f.get(); - return out; - } - catch(Exception e) { - throw new RuntimeException(e); 
- } - finally { - pool.shutdown(); - } - } - @Override public void decode(MatrixBlock in, FrameBlock out, int rl, int ru){ for( Decoder decoder : _decoders ) diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java index 0c4c6b42690..debce027680 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java @@ -27,31 +27,30 @@ import java.util.List; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.frame.data.columns.ColumnMetadata; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.util.UtilFunctions; /** - * Simple atomic decoder for dummycoded columns. This decoder builds internally - * inverted column mappings from the given frame meta data. - * + * Simple atomic decoder for dummycoded columns. This decoder builds internally inverted column mappings from the given + * frame meta data. 
+ * */ -public class DecoderDummycode extends Decoder -{ +public class DecoderDummycode extends Decoder { private static final long serialVersionUID = 4758831042891032129L; - + private int[] _clPos = null; private int[] _cuPos = null; - + protected DecoderDummycode(ValueType[] schema, int[] dcCols) { - //dcCols refers to column IDs in output (non-dc) + // dcCols refers to column IDs in output (non-dc) super(schema, dcCols); } @Override public FrameBlock decode(MatrixBlock in, FrameBlock out) { - //TODO perf (exploit sparse representation for better asymptotic behavior) out.ensureAllocatedColumns(in.getNumRows()); decode(in, out, 0, in.getNumRows()); return out; @@ -59,59 +58,98 @@ public FrameBlock decode(MatrixBlock in, FrameBlock out) { @Override public void decode(MatrixBlock in, FrameBlock out, int rl, int ru) { - //TODO perf (exploit sparse representation for better asymptotic behavior) - // out.ensureAllocatedColumns(in.getNumRows()); - for( int i=rl; i= low && aix[h] < high) { + int k = aix[h]; + int col = _colList[j] - 1; + out.getColumn(col).set(i, k - _clPos[j] + 1); + } + // limit the binary search. 
+ apos = h; + } + + } + @Override public Decoder subRangeDecoder(int colStart, int colEnd, int dummycodedOffset) { List dcList = new ArrayList<>(); List clPosList = new ArrayList<>(); List cuPosList = new ArrayList<>(); - + // get the column IDs for the sub range of the dummycode columns and their destination positions, // where they will be decoded to - for( int j=0; j<_colList.length; j++ ) { + for(int j = 0; j < _colList.length; j++) { int colID = _colList[j]; - if (colID >= colStart && colID < colEnd) { + if(colID >= colStart && colID < colEnd) { dcList.add(colID - (colStart - 1)); clPosList.add(_clPos[j] - dummycodedOffset); cuPosList.add(_cuPos[j] - dummycodedOffset); } } - if (dcList.isEmpty()) + if(dcList.isEmpty()) return null; // create sub-range decoder int[] colList = dcList.stream().mapToInt(i -> i).toArray(); - DecoderDummycode subRangeDecoder = new DecoderDummycode( - Arrays.copyOfRange(_schema, colStart - 1, colEnd - 1), colList); + DecoderDummycode subRangeDecoder = new DecoderDummycode(Arrays.copyOfRange(_schema, colStart - 1, colEnd - 1), + colList); subRangeDecoder._clPos = clPosList.stream().mapToInt(i -> i).toArray(); subRangeDecoder._cuPos = cuPosList.stream().mapToInt(i -> i).toArray(); return subRangeDecoder; } - + @Override public void updateIndexRanges(long[] beginDims, long[] endDims) { if(_colList == null) return; - + long lowerColDest = beginDims[1]; long upperColDest = endDims[1]; for(int i = 0; i < _colList.length; i++) { long numDistinct = _cuPos[i] - _clPos[i]; - + if(_cuPos[i] <= beginDims[1] + 1) if(numDistinct > 0) lowerColDest -= numDistinct - 1; - + if(_cuPos[i] <= endDims[1] + 1) if(numDistinct > 0) upperColDest -= numDistinct - 1; @@ -119,16 +157,25 @@ public void updateIndexRanges(long[] beginDims, long[] endDims) { beginDims[1] = lowerColDest; endDims[1] = upperColDest; } - + @Override public void initMetaData(FrameBlock meta) { - _clPos = new int[_colList.length]; //col lower pos - _cuPos = new int[_colList.length]; //col 
upper pos - for( int j=0, off=0; j<_colList.length; j++ ) { + _clPos = new int[_colList.length]; // col lower pos + _cuPos = new int[_colList.length]; // col upper pos + for(int j = 0, off = 0; j < _colList.length; j++) { int colID = _colList[j]; - ColumnMetadata d = meta.getColumnMetadata()[colID-1]; - int ndist = d.isDefault() ? 0 : (int)d.getNumDistinct(); - ndist = ndist < -1 ? 0: ndist; + ColumnMetadata d = meta.getColumnMetadata()[colID - 1]; + String v = meta.getString(0, colID - 1); + int ndist; + if(v.length() > 1 && v.charAt(0) == '¿') { + ndist = UtilFunctions.parseToInt(v.substring(1)); + } + else { + ndist = d.isDefault() ? 0 : (int) d.getNumDistinct(); + } + + ndist = ndist < -1 ? 0 : ndist; // safety if all values was null. + _clPos[j] = off + colID; _cuPos[j] = _clPos[j] + ndist; off += ndist - 1; diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java index 0a400e6da92..12ba2968877 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java @@ -64,34 +64,52 @@ public static Decoder createDecoder(String spec, String[] colnames, ValueType[] try { //parse transform specification JSONObject jSpec = new JSONObject(spec); - List ldecoders = new ArrayList<>(); - //create decoders 'bin', 'recode', 'dummy' and 'pass-through' + //create decoders 'bin', 'recode', 'hash', 'dummy', and 'pass-through' List binIDs = TfMetaUtils.parseBinningColIDs(jSpec, colnames, minCol, maxCol); List rcIDs = Arrays.asList(ArrayUtils.toObject( TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.RECODE.toString(), minCol, maxCol))); + List hcIDs = Arrays.asList(ArrayUtils.toObject( + TfMetaUtils.parseJsonIDList(jSpec, colnames, TfMethod.HASH.toString(), minCol, maxCol))); List dcIDs = Arrays.asList(ArrayUtils.toObject( TfMetaUtils.parseJsonIDList(jSpec, 
colnames, TfMethod.DUMMYCODE.toString(), minCol, maxCol))); + // only specially treat the columns with both recode and dictionary rcIDs = unionDistinct(rcIDs, dcIDs); + // remove hash recoded. // todo potentially wrong and remove? + rcIDs = except(rcIDs, hcIDs); + int len = dcIDs.isEmpty() ? Math.min(meta.getNumColumns(), clen) : meta.getNumColumns(); - List ptIDs = except(except(UtilFunctions.getSeqList(1, len, 1), rcIDs), binIDs); - + + // set the remaining columns to passthrough. + List ptIDs = UtilFunctions.getSeqList(1, len, 1); + // except recoded columns + ptIDs = except(ptIDs, rcIDs); + // binned columns + ptIDs = except(ptIDs, binIDs); + // hashed columns + ptIDs = except(ptIDs, hcIDs); // remove hashed columns + //create default schema if unspecified (with double columns for pass-through) if( schema == null ) { schema = UtilFunctions.nCopies(len, ValueType.STRING); for( Integer col : ptIDs ) schema[col-1] = ValueType.FP64; } + + // collect all the decoders in one list. + List ldecoders = new ArrayList<>(); if( !binIDs.isEmpty() ) { ldecoders.add(new DecoderBin(schema, - ArrayUtils.toPrimitive(binIDs.toArray(new Integer[0])))); + ArrayUtils.toPrimitive(binIDs.toArray(new Integer[0])), + ArrayUtils.toPrimitive(dcIDs.toArray(new Integer[0])))); } if( !dcIDs.isEmpty() ) { ldecoders.add(new DecoderDummycode(schema, ArrayUtils.toPrimitive(dcIDs.toArray(new Integer[0])))); } if( !rcIDs.isEmpty() ) { + // todo figure out if we need to handle rc columns with regards to dictionary offsets. 
ldecoders.add(new DecoderRecode(schema, !dcIDs.isEmpty(), ArrayUtils.toPrimitive(rcIDs.toArray(new Integer[0])))); } diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java index 5b6bf7a093e..c2de3ec1df3 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java @@ -49,7 +49,7 @@ protected DecoderPassThrough(ValueType[] schema, int[] ptCols, int[] dcCols) { _dcCols = dcCols; } - public DecoderPassThrough() { super(null, null); } + // public DecoderPassThrough() { super(null, null); } @Override public FrameBlock decode(MatrixBlock in, FrameBlock out) { @@ -61,13 +61,12 @@ public FrameBlock decode(MatrixBlock in, FrameBlock out) { @Override public void decode(MatrixBlock in, FrameBlock out, int rl, int ru) { int clen = Math.min(_colList.length, out.getNumColumns()); - for( int i=rl; i _dcCols[ix2] ColumnMetadata d =meta.getColumnMetadata()[_dcCols[ix2]-1]; - off += d.isDefault() ? -1 : d.getNumDistinct() - 1; + String v = meta.getString( 0,_dcCols[ix2]-1); + if(v.length() > 1 && v.charAt(0) == '¿'){ + off += UtilFunctions.parseToLong(v.substring(1)) -1; + } + else { + off += d.isDefault() ? 
-1 : d.getNumDistinct() - 1; + } ix2 ++; } } diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java index 33459a1c4f9..1cf0b7c4b3f 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java @@ -29,6 +29,7 @@ import java.util.Map.Entry; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.data.Pair; @@ -46,12 +47,11 @@ public class DecoderRecode extends Decoder private static final long serialVersionUID = -3784249774608228805L; private HashMap[] _rcMaps = null; - private Object[][] _rcMapsDirect = null; private boolean _onOut = false; - public DecoderRecode() { - super(null, null); - } + // public DecoderRecode() { + // super(null, null); + // } protected DecoderRecode(ValueType[] schema, boolean onOut, int[] rcCols) { super(schema, rcCols); @@ -59,8 +59,7 @@ protected DecoderRecode(ValueType[] schema, boolean onOut, int[] rcCols) { } public Object getRcMapValue(int i, long key) { - return (_rcMapsDirect != null && key > 0) ? 
- _rcMapsDirect[i][(int)key-1] : _rcMaps[i].get(key); + return _rcMaps[i].get(key); } @Override @@ -129,27 +128,33 @@ public void initMetaData(FrameBlock meta) { for( int j=0; j<_colList.length; j++ ) { HashMap map = new HashMap<>(); for( int i=0; i v < Integer.MAX_VALUE) ) { - _rcMapsDirect = new Object[_rcMaps.length][]; - for( int i=0; i<_rcMaps.length; i++ ) { - Object[] arr = new Object[(int)max[i]]; - for(Entry e1 : _rcMaps[i].entrySet()) - arr[e1.getKey().intValue()-1] = e1.getValue(); - _rcMapsDirect[i] = arr; - } - } + // if( Arrays.stream(max).allMatch(v -> v < Integer.MAX_VALUE) ) { + // _rcMapsDirect = new Object[_rcMaps.length][]; + // for( int i=0; i<_rcMaps.length; i++ ) { + // Object[] arr = new Object[(int)max[i]]; + // for(Entry e1 : _rcMaps[i].entrySet()) + // arr[e1.getKey().intValue()-1] = e1.getValue(); + // _rcMapsDirect[i] = arr; + // } + // } } /** diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java index 400b7f64ffc..361c9c52135 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java @@ -146,7 +146,9 @@ public FrameBlock getMetaData(FrameBlock meta) { return meta; meta.ensureAllocatedColumns(1); - meta.set(0, _colID - 1, String.valueOf(_K)); + // set metadata of hash columns to magical hash value + k + meta.set(0, _colID - 1, String.format("¿%d" , _K)); + return meta; } @@ -154,7 +156,7 @@ public FrameBlock getMetaData(FrameBlock meta) { public void initMetaData(FrameBlock meta) { if(meta == null || meta.getNumRows() <= 0) return; - _K = UtilFunctions.parseToLong(meta.get(0, _colID - 1).toString()); + _K = UtilFunctions.parseToLong(meta.getString(0, _colID - 1).substring(1)); } @Override From 566e2984ad4df645ea73600596988eed7fc09b6f Mon Sep 17 00:00:00 2001 From: Sebastian 
Baunsgaard Date: Mon, 8 Jun 2026 14:57:57 +0000 Subject: [PATCH 2/4] Fix sparse dummycode decode and restore decoder no-arg constructors Fix two regressions in the transform decode rewrite that broke encode/decode roundtrips on dummycoded/recoded frames: - DecoderDummycode.decodeSparse compared 0-based sparse column indexes against the 1-based _clPos/_cuPos bounds used by the dense path (in.get(i, k-1)). This shifted every lookup by one column, so the first category was never matched (decoded as null) and all others decoded one code too low. Shift the sparse bounds and index to be 0-based, matching the dense path. - Restore the public no-arg constructors on DecoderComposite, DecoderBin, DecoderPassThrough, and DecoderRecode. Decoder is Externalizable, and Spark broadcasts the top-level decoder via Java serialization, which requires a public no-arg constructor; without it deserialization fails with InvalidClassException on executors. Restores passing of TransformFrameEncodeColmapTest, TransformFrameEncodeDecodeTest, TransformCSVFrameEncodeDecodeTest, and TransformFrameEncodeDecodeTokenTest in single-node and Spark modes. 
--- .../sysds/runtime/transform/decode/DecoderBin.java | 6 +++--- .../runtime/transform/decode/DecoderComposite.java | 2 +- .../runtime/transform/decode/DecoderDummycode.java | 10 ++++++---- .../runtime/transform/decode/DecoderPassThrough.java | 2 +- .../sysds/runtime/transform/decode/DecoderRecode.java | 6 +++--- 5 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java index c9fcc23990a..79d9b7f3a40 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderBin.java @@ -49,9 +49,9 @@ public class DecoderBin extends Decoder { private double[][] _binMins = null; private double[][] _binMaxs = null; - // public DecoderBin() { - // super(null, null); - // } + public DecoderBin() { + super(null, null); + } protected DecoderBin(ValueType[] schema, int[] binCols, int[] dcCols) { super(schema, binCols); diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java index dff85e72dc6..adfef7bbc6d 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java @@ -47,7 +47,7 @@ protected DecoderComposite(ValueType[] schema, List decoders) { _decoders = decoders; } - // public DecoderComposite() { super(null, null); } + public DecoderComposite() { super(null, null); } @Override public FrameBlock decode(MatrixBlock in, FrameBlock out) { diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java index debce027680..95d7f4fa4c9 100644 --- 
a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java @@ -91,9 +91,11 @@ private void decodeSparseRow(FrameBlock out, final SparseBlock sb, int i) { final int[] aix = sb.indexes(i); for(int j = 0; j < _colList.length; j++) { // for each decode column. - // find k, the index in aix, within the range of low and high - final int low = _clPos[j]; - final int high = _cuPos[j]; + // find k, the index in aix, within the range of low and high. + // _clPos/_cuPos are 1-based matrix positions (the dense path reads + // in.get(i, k-1)); the sparse indexes in aix are 0-based, so shift. + final int low = _clPos[j] - 1; + final int high = _cuPos[j] - 1; int h = Arrays.binarySearch(aix, apos, alen, low); // start h at column. if(h < 0) // search gt col index (see binary search) h = Math.abs(h + 1); @@ -101,7 +103,7 @@ private void decodeSparseRow(FrameBlock out, final SparseBlock sb, int i) { if(h < alen && aix[h] >= low && aix[h] < high) { int k = aix[h]; int col = _colList[j] - 1; - out.getColumn(col).set(i, k - _clPos[j] + 1); + out.getColumn(col).set(i, k - low + 1); } // limit the binary search. 
apos = h; diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java index c2de3ec1df3..d2e7d59e81f 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java @@ -49,7 +49,7 @@ protected DecoderPassThrough(ValueType[] schema, int[] ptCols, int[] dcCols) { _dcCols = dcCols; } - // public DecoderPassThrough() { super(null, null); } + public DecoderPassThrough() { super(null, null); } @Override public FrameBlock decode(MatrixBlock in, FrameBlock out) { diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java index 1cf0b7c4b3f..a48759493fa 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java @@ -49,9 +49,9 @@ public class DecoderRecode extends Decoder private HashMap[] _rcMaps = null; private boolean _onOut = false; - // public DecoderRecode() { - // super(null, null); - // } + public DecoderRecode() { + super(null, null); + } protected DecoderRecode(ValueType[] schema, boolean onOut, int[] rcCols) { super(schema, rcCols); From 6b4e2fb019357106997ea527f79f448ac747fdcc Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Tue, 9 Jun 2026 13:55:49 +0000 Subject: [PATCH 3/4] Add component tests for transform decoder hash and metadata handling Cover the decoder paths touched by the hash-column and decode-metadata changes: parallel block decode equals serial decode, the sparse and dense dummycode decode paths agree, feature-hash columns decode through dummycode via the magic domain-size metadata, and bin columns whose source position is shifted by dummycoding of another column. 
Add exact inverse round-trip checks for recode and dummycode to validate the sparse binary-search decode against ground truth. --- .../TransformDecodeRoundTripTest.java | 167 ++++++++++++++++ .../frame/transform/TransformDecodeTest.java | 186 ++++++++++++++++++ 2 files changed, 353 insertions(+) create mode 100644 src/test/java/org/apache/sysds/test/component/frame/transform/TransformDecodeRoundTripTest.java create mode 100644 src/test/java/org/apache/sysds/test/component/frame/transform/TransformDecodeTest.java diff --git a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformDecodeRoundTripTest.java b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformDecodeRoundTripTest.java new file mode 100644 index 00000000000..9e8b55df29b --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformDecodeRoundTripTest.java @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.test.component.frame.transform; + +import static org.junit.Assert.fail; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.transform.decode.Decoder; +import org.apache.sysds.runtime.transform.decode.DecoderFactory; +import org.apache.sysds.runtime.transform.encode.EncoderFactory; +import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder; +import org.apache.sysds.test.TestUtils; +import org.junit.Before; +import org.junit.Test; + +/** + * Exact inverse correctness tests for the transform decoders. Recode and dummycode are lossless category encodings, so a + * decode of the encoded matrix must reconstruct the original categorical frame. These tests assert exact reconstruction + * for the dense path, the sparse path, and the parallel path so that the dummycode sparse binary search and the parallel + * block split are validated against ground truth rather than only against each other. 
+ */ +public class TransformDecodeRoundTripTest { + protected static final Log LOG = LogFactory.getLog(TransformDecodeRoundTripTest.class.getName()); + + @Before + public void setUp() { + // name must contain "main" so the parallel decode path reuses the shared thread pool + Thread.currentThread().setName("main_test_decode"); + } + + private static FrameBlock categoricalFrame() { + final String[] values = new String[] { /* 20 rows drawn from 6 distinct category strings */ + "apple", "banana", "apple", "cherry", "banana", "date", "apple", "cherry", "date", "banana", "elderberry", + "apple", "fig", "banana", "cherry", "apple", "date", "fig", "elderberry", "banana"}; + final FrameBlock f = new FrameBlock(new ValueType[] {ValueType.STRING}); + f.ensureAllocatedColumns(values.length); + for(int i = 0; i < values.length; i++) + f.set(i, 0, values[i]); + return f; + } + + @Test + public void recodeReconstructsOriginalDense() { + roundTrip("{ids:true, recode:[1]}", false, 1); + } + + @Test + public void recodeReconstructsOriginalSparse() { + roundTrip("{ids:true, recode:[1]}", true, 1); + } + + @Test + public void recodeReconstructsOriginalParallel() { + roundTrip("{ids:true, recode:[1]}", false, 4); + } + + @Test + public void dummycodeReconstructsOriginalDense() { + roundTrip("{ids:true, recode:[1], dummycode:[1]}", false, 1); + } + + @Test + public void dummycodeReconstructsOriginalSparse() { + // the one-hot encoded matrix is sparse, so this drives the dummycode sparse binary-search decode path + roundTrip("{ids:true, recode:[1], dummycode:[1]}", true, 1); + } + + @Test + public void dummycodeReconstructsOriginalParallel() { + roundTrip("{ids:true, recode:[1], dummycode:[1]}", false, 4); + } + + /** + * Binning a column while a different column is dummycoded shifts the bin column's source position in the encoded + * matrix. The bin decoder must rebuild that source-column mapping from the dummycode domain sizes.
This asserts the + * dense, sparse, and parallel decode paths agree for that layout (bin output is lossy, so exact reconstruction is + * not asserted, only cross-mode consistency and dimensions). + */ + @Test + public void binWithDummycodeOnOtherColumnConsistency() { + final String spec = "{ids:true, bin:[{id:1, method:equi-width, numbins:4}], dummycode:[2]}"; + try { + final FrameBlock original = TestUtils.generateRandomFrameBlock(150, + new ValueType[] {ValueType.FP32, ValueType.UINT4, ValueType.UINT8}, 4242); + final String[] colnames = original.getColumnNames(); + + final MultiColumnEncoder encoder = EncoderFactory.createEncoder(spec, colnames, original.getNumColumns(), + null); + final MatrixBlock encoded = encoder.encode(original, 1); + final FrameBlock meta = encoder.getMetaData(null); + + final MatrixBlock dense = new MatrixBlock(); + dense.copy(encoded); + if(dense.isInSparseFormat()) + dense.sparseToDense(); + + final MatrixBlock sparse = new MatrixBlock(); + sparse.copy(encoded); + if(!sparse.isInSparseFormat()) + sparse.denseToSparse(); + + final FrameBlock reference = decodeOnce(spec, colnames, meta, dense, 1); /* k=1 dense decode is the baseline both other runs are compared against */ + final FrameBlock parallel = decodeOnce(spec, colnames, meta, dense, 4); + final FrameBlock fromSparse = decodeOnce(spec, colnames, meta, sparse, 1); + + org.junit.Assert.assertEquals(original.getNumRows(), reference.getNumRows()); + TestUtils.compareFrames(reference, parallel, false); + TestUtils.compareFrames(reference, fromSparse, false); + } + catch(Exception e) { + e.printStackTrace(); + fail(spec + " : " + e.getMessage()); + } + } + + private static FrameBlock decodeOnce(String spec, String[] colnames, FrameBlock meta, MatrixBlock in, int k) { /* builds a new decoder and a new output frame on every call */ + final Decoder decoder = DecoderFactory.createDecoder(spec, colnames, null, meta, in.getNumColumns()); + return decoder.decode(in, new FrameBlock(decoder.getSchema()), k); + } + + private void roundTrip(String spec, boolean sparse, int k) { /* encode with spec, put the matrix in the requested storage format, decode with k, compare against the input frame */ + try { + final FrameBlock original = categoricalFrame();
+ final String[] colnames = original.getColumnNames(); + + final MultiColumnEncoder encoder = EncoderFactory.createEncoder(spec, colnames, original.getNumColumns(), + null); + MatrixBlock encoded = encoder.encode(original, 1); + final FrameBlock meta = encoder.getMetaData(null); + + if(sparse && !encoded.isInSparseFormat()) /* convert only when the encoder produced the other storage format */ + encoded.denseToSparse(); + else if(!sparse && encoded.isInSparseFormat()) + encoded.sparseToDense(); + + final Decoder decoder = DecoderFactory.createDecoder(spec, colnames, null, meta, encoded.getNumColumns()); + final FrameBlock decoded = decoder.decode(encoded, new FrameBlock(decoder.getSchema()), k); + + TestUtils.compareFrames(original, decoded, false); + } + catch(Exception e) { + e.printStackTrace(); + fail(spec + " (sparse=" + sparse + ", k=" + k + ") : " + e.getMessage()); + } + } +} diff --git a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformDecodeTest.java b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformDecodeTest.java new file mode 100644 index 00000000000..254937c20da --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformDecodeTest.java @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied.
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.frame.transform; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.logging.Level; +import java.util.logging.Logger; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.transform.decode.Decoder; +import org.apache.sysds.runtime.transform.decode.DecoderFactory; +import org.apache.sysds.runtime.transform.encode.EncoderFactory; +import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder; +import org.apache.sysds.runtime.util.CommonThreadPool; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameters; + +/** + * Component tests for the transform decoders. These exercise the row-block and parallel decode paths, the sparse and + * dense dummycode decode paths, the binning source-column offset mapping, and feature-hash column handling end-to-end + * through an encode followed by decode round trip. 
+ */ +@RunWith(value = Parameterized.class) +public class TransformDecodeTest { + protected static final Log LOG = LogFactory.getLog(TransformDecodeTest.class.getName()); + + private final FrameBlock data; + private final int k; + + public TransformDecodeTest(FrameBlock data, int k) { + // name must contain "main" so the parallel decode path reuses the shared thread pool + Thread.currentThread().setName("main_test_decode"); + Logger.getLogger(CommonThreadPool.class.getName()).setLevel(Level.OFF); + this.data = data; + this.k = k; + } + + @Parameters + public static Collection<Object[]> data() { /* one {FrameBlock, k} row per frame and thread count, in the constructor's argument order */ + final ArrayList<Object[]> tests = new ArrayList<>(); + final int[] threads = new int[] {1, 4}; + try { + final FrameBlock[] blocks = new FrameBlock[] { + // single low-cardinality categorical column + TestUtils.generateRandomFrameBlock(100, new ValueType[] {ValueType.UINT4}, 231), + // single categorical column with nulls + TestUtils.generateRandomFrameBlock(64, new ValueType[] {ValueType.UINT4}, 99, 0.2), + // multi column: dummycode/bin on col1 must offset the trailing passthrough columns + TestUtils.generateRandomFrameBlock(120, + new ValueType[] {ValueType.UINT4, ValueType.UINT8, ValueType.FP32}, 17), + // large enough to split into multiple row blocks in the parallel decode path + TestUtils.generateRandomFrameBlock(2500, new ValueType[] {ValueType.UINT4}, 7)}; + + for(FrameBlock block : blocks) + for(int k : threads) + tests.add(new Object[] {block, k}); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + return tests; + } + + @Test + public void testPassThrough() { + decodeConsistency("{ids:true}"); + } + + @Test + public void testRecode() { + decodeConsistency("{ids:true, recode:[1]}"); + } + + @Test + public void testDummycode() { + decodeConsistency("{ids:true, recode:[1], dummycode:[1]}"); + } + + @Test + public void testBinWidth() { + decodeConsistency("{ids:true, bin:[{id:1, method:equi-width, numbins:4}]}"); + } + + @Test + public void testBinHeight()
{ + decodeConsistency("{ids:true, bin:[{id:1, method:equi-height, numbins:10}]}"); + } + + @Test + public void testBinSingleBin() { + // numbins:1 forces the key==0 branch in the bin decoder + decodeConsistency("{ids:true, bin:[{id:1, method:equi-width, numbins:1}]}"); + } + + @Test + public void testHashToDummy() { + // feature-hash columns carry their domain size as the magic "¿K" metadata value, which the dummycode decoder + // must parse to reconstruct the one-hot column ranges + decodeConsistency("{ids:true, hash:[1], K:8, dummycode:[1]}"); + } + + @Test + public void testHashToDummyDomain1() { + decodeConsistency("{ids:true, hash:[1], K:1, dummycode:[1]}"); + } + + /** + * Encode the data, then decode the encoded matrix in three ways: serial dense, parallel dense, and serial sparse. + * All three must produce identical frames. This jointly exercises the parallel block-decode path in + * {@link Decoder#decode(MatrixBlock, FrameBlock, int)} and the separate sparse / dense dummycode decode paths.
+ */ + private void decodeConsistency(String spec) { + try { + final String[] colnames = data.getColumnNames(); + final MultiColumnEncoder encoder = EncoderFactory.createEncoder(spec, colnames, data.getNumColumns(), null); + final MatrixBlock encoded = encoder.encode(data, 1); + final FrameBlock meta = encoder.getMetaData(null); + + final MatrixBlock dense = forceDense(encoded); + final MatrixBlock sparse = forceSparse(encoded); + + final FrameBlock reference = decode(spec, colnames, meta, dense, 1); /* k=1 dense decode is the baseline both other runs are compared against */ + final FrameBlock parallel = decode(spec, colnames, meta, dense, k); + final FrameBlock fromSparse = decode(spec, colnames, meta, sparse, 1); + + assertEquals("decoded rows must match input rows", data.getNumRows(), reference.getNumRows()); + + TestUtils.compareFrames(reference, parallel, false); + TestUtils.compareFrames(reference, fromSparse, false); + } + catch(Exception e) { + e.printStackTrace(); + fail(spec + " : " + e.getMessage()); + } + } + + private static FrameBlock decode(String spec, String[] colnames, FrameBlock meta, MatrixBlock in, int k) { /* builds a new decoder and a new output frame on every call */ + final Decoder decoder = DecoderFactory.createDecoder(spec, colnames, null, meta, in.getNumColumns()); + return decoder.decode(in, new FrameBlock(decoder.getSchema()), k); + } + + private static MatrixBlock forceDense(MatrixBlock in) { + final MatrixBlock out = new MatrixBlock(); + out.copy(in); /* convert a copy; the same encoded block is also handed to forceSparse */ + if(out.isInSparseFormat()) + out.sparseToDense(); + return out; + } + + private static MatrixBlock forceSparse(MatrixBlock in) { + final MatrixBlock out = new MatrixBlock(); + out.copy(in); /* convert a copy; the same encoded block is also handed to forceDense */ + if(!out.isInSparseFormat()) + out.denseToSparse(); + return out; + } +} From d0083b156ec164d51fb2929b54bf68b14fa84608 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Tue, 9 Jun 2026 22:15:23 +0000 Subject: [PATCH 4/4] Speed up boolean-token fallback in StringArray.getAsDouble Replace the toLowerCase plus equals chain in the parse fallback with a length-based dispatch: a single char compare for the 1-char "t"/"f" tokens and
compareToIgnoreCase for "true"/"false", matching the idiom already used in DoubleArray.parseDouble. This avoids allocating a lower-cased copy and rejects non-boolean strings immediately. Restore throwing DMLRuntimeException on unparseable input. The previous re-throw of the raw NumberFormatException changed the exception type and broke callers such as Array.extractDouble that expect DMLRuntimeException; the throw path is the genuinely-exceptional case, so the wrapping cost is irrelevant there. --- .../frame/data/columns/StringArray.java | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java index 292fcb52bf5..8156a98cd35 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/StringArray.java @@ -610,14 +610,22 @@ private static double getAsDouble(String s) { return DoubleArray.parseDouble(s); } catch(Exception e) { - String ls = s.toLowerCase(); - if(ls.equals("true") || ls.equals("t")) + // Fallback for boolean-like tokens. Dispatch on length first so non-boolean strings are + // rejected immediately, and avoid allocating a lower-cased copy by comparing case-insensitively + // (single char compare for the 1-char tokens). + final int len = s.length(); + if(len == 1) { + final char c = s.charAt(0); + if(c == 't' || c == 'T') + return 1; + else if(c == 'f' || c == 'F') + return 0; + } + else if(len == 4 && s.compareToIgnoreCase("true") == 0) return 1; - else if(ls.equals("false") || ls.equals("f")) + else if(len == 5 && s.compareToIgnoreCase("false") == 0) return 0; - else - throw e; // for efficiency - // throw new DMLRuntimeException("Unable to change to double: " + s, e); + throw new DMLRuntimeException("Unable to change to double: " + s, e); } }