diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java index e21c539d6d8..f5719641df7 100644 --- a/src/main/java/org/apache/sysds/common/Builtins.java +++ b/src/main/java/org/apache/sysds/common/Builtins.java @@ -154,6 +154,7 @@ public enum Builtins { GARCH("garch", true), GAUSSIAN_CLASSIFIER("gaussianClassifier", true), GET_ACCURACY("getAccuracy", true), + GET_CATEGORICAL_MASK("getCategoricalMask", false), GLM("glm", true), GLM_PREDICT("glmPredict", true), GLOVE("glove", true), diff --git a/src/main/java/org/apache/sysds/common/Opcodes.java b/src/main/java/org/apache/sysds/common/Opcodes.java index 1b0536416d6..9a894dde13b 100644 --- a/src/main/java/org/apache/sysds/common/Opcodes.java +++ b/src/main/java/org/apache/sysds/common/Opcodes.java @@ -215,6 +215,8 @@ public enum Opcodes { TRANSFORMMETA("transformmeta", InstructionType.ParameterizedBuiltin), TRANSFORMENCODE("transformencode", InstructionType.MultiReturnParameterizedBuiltin, InstructionType.MultiReturnBuiltin), + GET_CATEGORICAL_MASK("get_categorical_mask", InstructionType.Binary), + //Ternary instruction opcodes PM("+*", InstructionType.Ternary), MINUSMULT("-*", InstructionType.Ternary), diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java index 2e3543882d2..c2832aeb8cd 100644 --- a/src/main/java/org/apache/sysds/common/Types.java +++ b/src/main/java/org/apache/sysds/common/Types.java @@ -639,6 +639,7 @@ public enum OpOp2 { MINUS_NZ(false), //sparse-safe minus: X-(mean*ppred(X,0,!=)) LOG_NZ(false), //sparse-safe log; ppred(X,0,"!=")*log(X,0.5) MINUS1_MULT(false), //1-X*Y + GET_CATEGORICAL_MASK(false), // get transformation mask QUANTIZE_COMPRESS(false), //quantization-fused compression UNION_DISTINCT(false); diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java b/src/main/java/org/apache/sysds/hops/BinaryOp.java index 2b803a053c1..dc7edf76e50 100644 --- a/src/main/java/org/apache/sysds/hops/BinaryOp.java +++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java @@ -853,7 +853,10 @@ else if( (op == OpOp2.CBIND && getDataType().isList()) || (op == OpOp2.RBIND && getDataType().isList())) { _etype = ExecType.CP; } - + + if( op == OpOp2.GET_CATEGORICAL_MASK) + _etype = ExecType.CP; + //mark for recompile (forever) setRequiresRecompileIfNecessary(); diff --git a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java index 28f6949f722..ab0c7993b4e 100644 --- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java @@ -2018,6 +2018,15 @@ else if(this.getOpCode() == Builtins.MAX_POOL || this.getOpCode() == Builtins.AV else raiseValidateError("The compress or decompress instruction is not allowed in dml scripts"); break; + case GET_CATEGORICAL_MASK: + checkNumParameters(2); + checkFrameParam(getFirstExpr()); + checkScalarParam(getSecondExpr()); + output.setDataType(DataType.MATRIX); + output.setDimensions(1, -1); + output.setBlocksize( id.getBlocksize()); + output.setValueType(ValueType.FP64); + break; case QUANTIZE_COMPRESS: if(OptimizerUtils.ALLOW_SCRIPT_LEVEL_QUANTIZE_COMPRESS_COMMAND) { checkNumParameters(2); @@ -2383,6 +2392,13 @@ protected void checkMatrixFrameParam(Expression e) { //always unconditional raiseValidateError("Expecting matrix or frame parameter for function "+ getOpCode(), false, LanguageErrorCodes.UNSUPPORTED_PARAMETERS); } } + + protected void checkFrameParam(Expression e) { + if(e.getOutput().getDataType() != DataType.FRAME) { + raiseValidateError("Expecting frame parameter for function " + getOpCode(), false, + LanguageErrorCodes.UNSUPPORTED_PARAMETERS); + } + } protected void checkMatrixScalarParam(Expression e) { //always unconditional if (e.getOutput().getDataType() != DataType.MATRIX && e.getOutput().getDataType() != DataType.SCALAR) { diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java b/src/main/java/org/apache/sysds/parser/DMLTranslator.java index c6e7188d7bc..e14cfd31388 100644 --- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java +++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java @@ -2821,6 +2821,9 @@ else if ( in.length == 2 ) DataType.MATRIX, target.getValueType(), AggOp.COUNT_DISTINCT, Direction.Col, expr); break; + case GET_CATEGORICAL_MASK: + currBuiltinOp = new BinaryOp(target.getName(), DataType.MATRIX, ValueType.FP64, OpOp2.GET_CATEGORICAL_MASK, expr, expr2); + break; default: throw new ParseException("Unsupported builtin function type: "+source.getOpCode()); } diff --git a/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java b/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java index 39735be62e0..eed2c58f78c 100644 --- a/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java +++ b/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java @@ -54,7 +54,7 @@ public enum BuiltinCode { AUTODIFF, SIN, COS, TAN, SINH, COSH, TANH, ASIN, ACOS, MAX, ABS, SIGN, SQRT, EXP, PLOGP, PRINT, PRINTF, NROW, NCOL, LENGTH, LINEAGE, ROUND, MAXINDEX, MININDEX, STOP, CEIL, FLOOR, CUMSUM, ROWCUMSUM, CUMPROD, CUMMIN, CUMMAX, CUMSUMPROD, INVERSE, SPROP, SIGMOID, EVAL, LIST, TYPEOF, APPLY_SCHEMA, DETECTSCHEMA, ISNA, ISNAN, ISINF, DROP_INVALID_TYPE, - DROP_INVALID_LENGTH, VALUE_SWAP, FRAME_ROW_REPLICATE, + DROP_INVALID_LENGTH, VALUE_SWAP, FRAME_ROW_REPLICATE, GET_CATEGORICAL_MASK, MAP, COUNT_DISTINCT, COUNT_DISTINCT_APPROX, UNIQUE} private static final VectorSpecies SPECIES = DoubleVector.SPECIES_PREFERRED; @@ -120,6 +120,7 @@ public enum BuiltinCode { AUTODIFF, SIN, COS, TAN, SINH, COSH, TANH, ASIN, ACOS, String2BuiltinCode.put( "_map", BuiltinCode.MAP); String2BuiltinCode.put( "valueSwap", BuiltinCode.VALUE_SWAP); String2BuiltinCode.put( "applySchema", BuiltinCode.APPLY_SCHEMA); + String2BuiltinCode.put( "get_categorical_mask", BuiltinCode.GET_CATEGORICAL_MASK); } protected Builtin(BuiltinCode bf) { diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java index 28b8775ebd5..86184f47be6 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java @@ -59,6 +59,8 @@ else if (in1.getDataType() == DataType.TENSOR && in2.getDataType() == DataType.T return new BinaryTensorTensorCPInstruction(operator, in1, in2, out, opcode, str); else if (in1.getDataType() == DataType.FRAME && in2.getDataType() == DataType.FRAME) return new BinaryFrameFrameCPInstruction(operator, in1, in2, out, opcode, str); + else if (in1.getDataType() == DataType.FRAME && in2.getDataType() == DataType.SCALAR) + return new BinaryFrameScalarCPInstruction(operator, in1, in2, out, opcode, str); else if (in1.getDataType() == DataType.FRAME && in2.getDataType() == DataType.MATRIX) return new BinaryFrameMatrixCPInstruction(operator, in1, in2, out, opcode, str); else diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameScalarCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameScalarCPInstruction.java new file mode 100644 index 00000000000..193894fd9bc --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameScalarCPInstruction.java @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.instructions.cp; + +import java.util.Arrays; + +import org.apache.sysds.common.Builtins; +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.columns.ColumnMetadata; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.operators.MultiThreadedOperator; +import org.apache.sysds.runtime.transform.TfUtils.TfMethod; +import org.apache.sysds.runtime.util.UtilFunctions; +import org.apache.wink.json4j.JSONException; +import org.apache.wink.json4j.JSONObject; + +public class BinaryFrameScalarCPInstruction extends BinaryCPInstruction { + // private static final Log LOG = LogFactory.getLog(BinaryFrameFrameCPInstruction.class.getName()); + + private static final TfMethod[] UNSUPPORTED_MASK_METHODS = new TfMethod[] {TfMethod.BIN, + TfMethod.WORD_EMBEDDING, TfMethod.BAG_OF_WORDS, TfMethod.UDF}; + + protected BinaryFrameScalarCPInstruction(MultiThreadedOperator op, CPOperand in1, CPOperand in2, CPOperand out, + String opcode, String istr) { + super(CPType.Binary, op, in1, in2, out, opcode, istr); + } + + @Override + public void processInstruction(ExecutionContext ec) { + // get input frames + FrameBlock inBlock1 = ec.getFrameInput(input1.getName()); + ScalarObject spec = ec.getScalarInput(input2.getName(), ValueType.STRING, true); + if(getOpcode().equals(Builtins.GET_CATEGORICAL_MASK.toString().toLowerCase())) { + processGetCategorical(ec, inBlock1, spec); + } + else { + throw new DMLRuntimeException("Unsupported operation"); + } + + // Release the memory occupied by input frames + ec.releaseFrameInput(input1.getName()); + } + + private static void validate(JSONObject jSpec) { + try { + if(!jSpec.containsKey("ids") || !jSpec.getBoolean("ids")) + throw new DMLRuntimeException("not supported non ID based spec for get_categorical_mask"); + + for(TfMethod m : UNSUPPORTED_MASK_METHODS) + if(jSpec.containsKey(m.toString())) + throw new DMLRuntimeException("unsupported transform method '" + m + "' for get_categorical_mask"); + } + catch(JSONException e) { + throw new DMLRuntimeException(e); + } + } + + public void processGetCategorical(ExecutionContext ec, FrameBlock f, ScalarObject spec) { + try { + // 1. extract the spec, 2. validate it + JSONObject jSpec = new JSONObject(spec.getStringValue()); + validate(jSpec); + + // 3.-5. fold each supported transform method into the per-column mask state + CategoricalMask mask = new CategoricalMask(f, jSpec); + mask.hash(); + mask.recode(); + mask.dummycode(); + + // 6.-7. size and materialize the output mask + ec.setMatrixOutput(output.getName(), mask.toMatrixBlock()); + } + catch(Exception e) { + throw new DMLRuntimeException(e); + } + } + + /** + * Accumulates, per input column, how many output columns it expands to (lengths) and whether those + * output columns are categorical (categorical). The arrays are allocated lazily: a column that no + * method touches keeps the implicit default of a single, non-categorical output column. + */ + private static final class CategoricalMask { + private final FrameBlock f; + private final JSONObject jSpec; + private final int nCol; + + private int[] lengths = null; + private boolean[] categorical = null; + + // feature-hashed columns map to K buckets; a plain hashed column produces a single + // (categorical) bucket-id column, while a hashed column that is additionally dummycoded + // expands to K columns. + private boolean[] hashed = null; + private int K = 0; + + private CategoricalMask(FrameBlock f, JSONObject jSpec) { + this.f = f; + this.jSpec = jSpec; + this.nCol = f.getNumColumns(); + } + + private void hash() throws JSONException { + String hash = TfMethod.HASH.toString(); + if(!jSpec.containsKey(hash)) + return; + K = jSpec.getInt("K"); + hashed = new boolean[nCol]; + ensureCategorical(); + for(Object aa : jSpec.getJSONArray(hash)) { + int av = (Integer) aa - 1; + hashed[av] = true; + categorical[av] = true; + } + } + + private void recode() throws JSONException { + String recode = TfMethod.RECODE.toString(); + if(!jSpec.containsKey(recode)) + return; + ensureCategorical(); + for(Object aa : jSpec.getJSONArray(recode)) { + int av = (Integer) aa - 1; + categorical[av] = true; + } + } + + private void dummycode() throws JSONException { + String dummycode = TfMethod.DUMMYCODE.toString(); + if(!jSpec.containsKey(dummycode)) + return; + ensureCategorical(); + ensureLengths(); + for(Object aa : jSpec.getJSONArray(dummycode)) { + int av = (Integer) aa - 1; + lengths[av] = distinctCount(av); + categorical[av] = true; + } + } + + private int distinctCount(int av) { + if(hashed != null && hashed[av]) + // feature hashing followed by dummycoding yields K columns + return K; + ColumnMetadata d = f.getColumnMetadata()[av]; + String v = f.getString(0, av); + if(v.length() > 1 && v.charAt(0) == '¿') + return UtilFunctions.parseToInt(v.substring(1)); + return d.isDefault() ? 0 : (int) d.getNumDistinct(); + } + + private int sumLengths() { + if(lengths == null) + return nCol; + int sum = 0; + for(int i = 0; i < nCol; i++) + sum += lengths[i]; + return sum; + } + + private MatrixBlock toMatrixBlock() { + MatrixBlock ret = new MatrixBlock(1, sumLengths(), false); + ret.allocateDenseBlock(); + int off = 0; + for(int i = 0; i < nCol; i++) { + int len = (lengths == null) ? 1 : lengths[i]; + double val = (categorical != null && categorical[i]) ? 1 : 0; + for(int j = 0; j < len; j++) + ret.set(0, off++, val); + } + return ret; + } + + private void ensureCategorical() { + if(categorical == null) + categorical = new boolean[nCol]; + } + + private void ensureLengths() { + if(lengths == null) { + lengths = new int[nCol]; + Arrays.fill(lengths, 1); + } + } + } +} diff --git a/src/test/java/org/apache/sysds/test/TestUtils.java b/src/test/java/org/apache/sysds/test/TestUtils.java index 5ebc243dd44..86db894f8e3 100644 --- a/src/test/java/org/apache/sysds/test/TestUtils.java +++ b/src/test/java/org/apache/sysds/test/TestUtils.java @@ -2941,6 +2941,25 @@ public static void writeTestScalar(String file, double value) { } } + + /** + * Write scalar to file + * + * @param file File to write to + * @param value Value to write + */ + public static void writeTestScalar(String file, String value) { + try { + DataOutputStream out = new DataOutputStream(new FileOutputStream(file)); + try(PrintWriter pw = new PrintWriter(out)) { + pw.println(value); + } + } + catch(IOException e) { + fail("unable to write test scalar (" + file + "): " + e.getMessage()); + } + } + /** * Write scalar to file * diff --git a/src/test/java/org/apache/sysds/test/component/frame/transform/GetCategoricalMaskInstructionTest.java b/src/test/java/org/apache/sysds/test/component/frame/transform/GetCategoricalMaskInstructionTest.java new file mode 100644 index 00000000000..d9c540f54c5 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/frame/transform/GetCategoricalMaskInstructionTest.java @@ -0,0 +1,306 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.frame.transform; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.common.Types.DataType; +import org.apache.sysds.common.Types.FileFormat; +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.caching.CacheableData; +import org.apache.sysds.runtime.controlprogram.caching.FrameObject; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContextFactory; +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.columns.ColumnMetadata; +import org.apache.sysds.runtime.instructions.InstructionUtils; +import org.apache.sysds.runtime.instructions.cp.BinaryCPInstruction; +import org.apache.sysds.runtime.instructions.cp.BinaryFrameScalarCPInstruction; +import org.apache.sysds.runtime.instructions.cp.StringObject; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.runtime.meta.MetaDataFormat; +import org.junit.BeforeClass; +import org.junit.Test; + +/** + * Unit tests that drive the get_categorical_mask instruction directly to exercise the defensive code + * paths (distinct-count prefix in the metadata frame, default column metadata, non id-based specs and + * the unsupported opcode guard) that the script-level transform tests cannot reach. + */ +public class GetCategoricalMaskInstructionTest { + protected static final Log LOG = LogFactory.getLog(GetCategoricalMaskInstructionTest.class.getName()); + + private static final String MASK_OPCODE = "get_categorical_mask"; + + @BeforeClass + public static void init() throws java.io.IOException { + CacheableData.initCaching("get_categorical_mask_instruction_test"); + } + + @Test + public void dummycodeReadsDistinctCountFromMetadataPrefix() { + // a metadata cell prefixed with '¿' encodes the number of distinct values inline + FrameBlock meta = new FrameBlock(new ValueType[] {ValueType.STRING}, new String[][] {{"¿3"}}); + MatrixBlock res = run(meta, "{\"ids\": true, \"dummycode\": [1]}"); + + assertEquals(1, res.getNumRows()); + assertEquals(3, res.getNumColumns()); + assertArrayEquals(new double[] {1, 1, 1}, res.getDenseBlockValues(), 0.0); + } + + @Test + public void dummycodeDefaultMetadataContributesNoColumns() { + // first column is dummycoded but carries default metadata (no distinct count) -> 0 columns, + // the trailing pass-through column keeps the output non-empty + FrameBlock meta = new FrameBlock(new ValueType[] {ValueType.STRING, ValueType.STRING}, + new String[][] {{"x", "y"}}); + MatrixBlock res = run(meta, "{\"ids\": true, \"dummycode\": [1]}"); + + assertEquals(1, res.getNumRows()); + assertEquals(1, res.getNumColumns()); + assertEquals(0.0, res.get(0, 0), 0.0); + } + + @Test + public void noMethodAllColumnsPassThrough() { + // a spec with only "ids" touches no column: every column is a single, non-categorical output + FrameBlock meta = metaWithDistinct(3, new int[] {0, 0, 0}); + MatrixBlock res = run(meta, "{\"ids\": true}"); + + assertMask(res, new double[] {0, 0, 0}); + } + + @Test + public void recodeInterleavedWithPassThrough() { + // categorical (recode, 1 col each) interleaved with continuous pass-through columns + FrameBlock meta = metaWithDistinct(5, new int[] {0, 0, 0, 0, 0}); + MatrixBlock res = run(meta, "{\"ids\": true, \"recode\": [1, 4]}"); + + assertMask(res, new double[] {1, 0, 0, 1, 0}); + } + + @Test + public void leadingPassThroughThenDummycodeOffsets() { + // the dummycode expansion must start at the correct offset after three continuous columns + FrameBlock meta = metaWithDistinct(4, new int[] {0, 0, 0, 3}); + MatrixBlock res = run(meta, "{\"ids\": true, \"dummycode\": [4]}"); + + assertMask(res, new double[] {0, 0, 0, 1, 1, 1}); + } + + @Test + public void multipleDummycodeVaryingDistinctCounts() { + // several dummycoded columns of different widths, all categorical, no pass-through + FrameBlock meta = metaWithDistinct(3, new int[] {2, 4, 1}); + MatrixBlock res = run(meta, "{\"ids\": true, \"dummycode\": [1, 2, 3]}"); + + assertMask(res, new double[] {1, 1, 1, 1, 1, 1, 1}); + } + + @Test + public void dummycodeAndPassThroughAndRecodeInterleaved() { + // dummycode(3) | pass-through | recode | dummycode(2): exercises every offset transition + FrameBlock meta = metaWithDistinct(4, new int[] {3, 0, 0, 2}); + MatrixBlock res = run(meta, "{\"ids\": true, \"recode\": [3], \"dummycode\": [1, 4]}"); + + assertMask(res, new double[] {1, 1, 1, 0, 1, 1, 1}); + } + + @Test + public void recodeAndDummycodeOnSameColumnExpands() { + // a column listed in both recode and dummycode must expand to its dummycode width, not collapse + FrameBlock meta = metaWithDistinct(2, new int[] {4, 0}); + MatrixBlock res = run(meta, "{\"ids\": true, \"recode\": [1], \"dummycode\": [1]}"); + + assertMask(res, new double[] {1, 1, 1, 1, 0}); + } + + @Test + public void hashOnlyColumnStaysSingleCategorical() { + // a hashed-but-not-dummycoded column is a single categorical column; K must not widen it + FrameBlock meta = metaWithDistinct(3, new int[] {0, 0, 0}); + MatrixBlock res = run(meta, "{\"ids\": true, \"hash\": [2], \"K\": 5}"); + + assertMask(res, new double[] {0, 1, 0}); + } + + @Test + public void hashDummycodeRecodePassThroughMixed() { + // col1: hash+dummycode -> K=3 (metadata ignored); col2: pass-through; col3: dummycode(9); + // col4: pass-through; col5: recode. Verifies hashed columns use K while plain dummycode uses + // the metadata distinct count, with correct offsets across the whole row. + FrameBlock meta = metaWithDistinct(5, new int[] {0, 0, 9, 0, 0}); + MatrixBlock res = run(meta, "{\"ids\": true, \"recode\": [5], \"dummycode\": [1, 3], \"hash\": [1], \"K\": 3}"); + + assertMask(res, new double[] {1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1}); + } + + @Test + public void nonIdSpecMissingIdsKeyThrows() { + // a spec without the "ids" key must be rejected, not silently mis-interpreted + FrameBlock meta = new FrameBlock(new ValueType[] {ValueType.STRING}, new String[][] {{"a"}}); + assertThrowsMessage("non ID based spec", () -> run(meta, "{\"recode\": [1]}")); + } + + @Test + public void nonIdSpecIdsFalseThrows() { + // "ids": false is equally unsupported + FrameBlock meta = new FrameBlock(new ValueType[] {ValueType.STRING}, new String[][] {{"a"}}); + assertThrowsMessage("non ID based spec", () -> run(meta, "{\"ids\": false, \"recode\": [1]}")); + } + + @Test + public void unsupportedBinMethodThrows() { + // bin expands to bin-count columns under dummycode, which the mask does not model + FrameBlock meta = new FrameBlock(new ValueType[] {ValueType.STRING}, new String[][] {{"a"}}); + assertThrowsMessage("unsupported transform method 'bin'", + () -> run(meta, "{\"ids\": true, \"bin\": [{\"id\": 1, \"method\": \"equi-width\", \"numbins\": 3}]}")); + } + + @Test + public void unsupportedWordEmbeddingMethodThrows() { + // word_embedding maps a column to an embedding vector (many columns), not a single mask entry + FrameBlock meta = new FrameBlock(new ValueType[] {ValueType.STRING}, new String[][] {{"a"}}); + assertThrowsMessage("unsupported transform method 'word_embedding'", + () -> run(meta, "{\"ids\": true, \"word_embedding\": [1]}")); + } + + @Test + public void unsupportedBagOfWordsMethodThrows() { + // bag_of_words expands to one column per dictionary token + FrameBlock meta = new FrameBlock(new ValueType[] {ValueType.STRING}, new String[][] {{"a"}}); + assertThrowsMessage("unsupported transform method 'bag_of_words'", + () -> run(meta, "{\"ids\": true, \"bag_of_words\": [1]}")); + } + + @Test + public void unsupportedUdfMethodThrows() { + // udf output arity is user-defined and cannot be inferred from the spec + FrameBlock meta = new FrameBlock(new ValueType[] {ValueType.STRING}, new String[][] {{"a"}}); + assertThrowsMessage("unsupported transform method 'udf'", + () -> run(meta, "{\"ids\": true, \"udf\": {\"name\": \"f\", \"ids\": [1]}}")); + } + + @Test + public void imputeAndOmitAreAccepted() { + // impute and omit do not change the output column count or categorical flag, so a spec that + // only adds them on top of a recoded column must still succeed and mark that column categorical + FrameBlock meta = new FrameBlock(new ValueType[] {ValueType.STRING}, new String[][] {{"a"}}); + MatrixBlock res = run(meta, "{\"ids\": true, \"recode\": [1], \"impute\": [{\"id\": 1, \"method\": \"global_mode\"}], \"omit\": [1]}"); + + assertEquals(1, res.getNumRows()); + assertEquals(1, res.getNumColumns()); + assertEquals(1.0, res.get(0, 0), 0.0); + } + + @Test + public void malformedSpecWrapsJsonException() { + // "ids" present but not a boolean makes spec parsing throw a JSONException, which must be + // wrapped as a DMLRuntimeException rather than propagating raw + FrameBlock meta = new FrameBlock(new ValueType[] {ValueType.STRING}, new String[][] {{"a"}}); + assertThrowsMessage("was not a boolean", () -> run(meta, "{\"ids\": 5, \"recode\": [1]}")); + } + + @Test + public void unsupportedOpcodeThrows() { + // any frame-scalar binary opcode other than get_categorical_mask must be rejected + ExecutionContext ec = ExecutionContextFactory.createContext(); + ec.setAutoCreateVars(true); + ec.setVariable("F", frameObject(new FrameBlock(new ValueType[] {ValueType.STRING}, + new String[][] {{"a"}}))); + assertThrowsMessage("Unsupported operation", () -> maskInstruction("+").processInstruction(ec)); + } + + /** Assert the action throws a DMLRuntimeException whose message chain contains the expected text. */ + private static void assertThrowsMessage(String expected, Runnable action) { + try { + action.run(); + fail("Expected DMLRuntimeException containing \"" + expected + "\" but nothing was thrown"); + } + catch(DMLRuntimeException e) { + StringBuilder chain = new StringBuilder(); + for(Throwable t = e; t != null; t = t.getCause()) + chain.append(t.getMessage()).append(" | "); + assertTrue("Exception chain [" + chain + "] should contain \"" + expected + "\"", + chain.toString().contains(expected)); + } + } + + /** Assert the mask is a single row equal to the expected values (which also fixes its width). */ + private static void assertMask(MatrixBlock res, double[] expected) { + assertEquals(1, res.getNumRows()); + assertEquals(expected.length, res.getNumColumns()); + // compare per cell rather than via getDenseBlockValues(): an all-zero mask has nnz == 0 and + // therefore no materialized dense block + double[] actual = new double[expected.length]; + for(int i = 0; i < expected.length; i++) + actual[i] = res.get(0, i); + assertArrayEquals(expected, actual, 0.0); + } + + /** + * Build a single-row metadata frame of nCol string columns. A positive distinct[i] is written to + * that column's metadata as the recode distinct count (the path real transformencode uses), while + * a zero leaves the column with default metadata (a continuous / non-dummycoded column). + */ + private static FrameBlock metaWithDistinct(int nCol, int[] distinct) { + ValueType[] schema = new ValueType[nCol]; + String[][] data = new String[1][nCol]; + for(int i = 0; i < nCol; i++) { + schema[i] = ValueType.STRING; + data[0][i] = "v"; + } + FrameBlock fb = new FrameBlock(schema, data); + for(int i = 0; i < nCol; i++) + if(distinct[i] > 0) + fb.setColumnMetadata(i, new ColumnMetadata(distinct[i])); + return fb; + } + + private static MatrixBlock run(FrameBlock meta, String spec) { + ExecutionContext ec = ExecutionContextFactory.createContext(); + ec.setAutoCreateVars(true); + maskInstruction(MASK_OPCODE).processGetCategorical(ec, meta, new StringObject(spec)); + return ec.getMatrixObject("out").acquireReadAndRelease(); + } + + private static BinaryFrameScalarCPInstruction maskInstruction(String opcode) { + String in1 = InstructionUtils.concatOperandParts("F", DataType.FRAME.name(), ValueType.STRING.name(), "false"); + String in2 = InstructionUtils.concatOperandParts("spec", DataType.SCALAR.name(), ValueType.STRING.name(), "true"); + String out = InstructionUtils.concatOperandParts("out", DataType.MATRIX.name(), ValueType.FP64.name(), "false"); + String str = InstructionUtils.concatOperands("CP", opcode, in1, in2, out); + return (BinaryFrameScalarCPInstruction) BinaryCPInstruction.parseInstruction(str); + } + + private static FrameObject frameObject(FrameBlock fb) { + MatrixCharacteristics mc = new MatrixCharacteristics(fb.getNumRows(), fb.getNumColumns(), -1, -1); + FrameObject fo = new FrameObject("F", new MetaDataFormat(mc, FileFormat.BINARY), fb.getSchema()); + fo.acquireModify(fb); + fo.release(); + return fo; + } +} diff --git a/src/test/java/org/apache/sysds/test/functions/transform/GetCategoricalMaskTest.java b/src/test/java/org/apache/sysds/test/functions/transform/GetCategoricalMaskTest.java new file mode 100644 index 00000000000..30681f373e4 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/transform/GetCategoricalMaskTest.java @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.transform; + +import static org.junit.Assert.fail; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.common.Types.FileFormat; +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; + +public class GetCategoricalMaskTest extends AutomatedTestBase { + protected static final Log LOG = LogFactory.getLog(GetCategoricalMaskTest.class.getName()); + + private final static String TEST_NAME1 = "GetCategoricalMaskTest"; + private final static String TEST_DIR = "functions/transform/"; + private final static String TEST_CLASS_DIR = TEST_DIR + TransformFrameEncodeApplyTest.class.getSimpleName() + "/"; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"y"})); + } + + @Test + public void testRecode() throws Exception { + FrameBlock fb = TestUtils.generateRandomFrameBlock(10, new ValueType[] {ValueType.UINT8}, 32); + MatrixBlock expected = new MatrixBlock(1, 1, 1.0); + String spec = "{\"ids\": true, \"recode\": [1]}"; + runTransformTest(fb, spec, expected); + + } + + @Test + public void testRecode2() throws Exception { + FrameBlock fb = TestUtils.generateRandomFrameBlock(10, new ValueType[] {ValueType.UINT8, ValueType.UINT8}, 32); + MatrixBlock expected = new MatrixBlock(1, 2, new double[] {0, 1}); + + String spec = "{\"ids\": true, \"recode\": [2]}"; + runTransformTest(fb, spec, expected); + + } + + @Test + public void testDummy1() throws Exception { + FrameBlock fb = TestUtils.generateRandomFrameBlock(5, new ValueType[] {ValueType.UINT8, ValueType.INT64}, 32); + MatrixBlock expected = new MatrixBlock(1, 6, new double[] {0, 1, 1, 1, 1, 1}); + + String spec = "{\"ids\": true, \"dummycode\": [2]}"; + runTransformTest(fb, spec, expected); + + } + + @Test + public void testDummy2() throws Exception { + FrameBlock fb = TestUtils.generateRandomFrameBlock(5, new ValueType[] {ValueType.UINT8, ValueType.INT64}, 32); + MatrixBlock expected = new MatrixBlock(1, 6, new double[] {1, 1, 1, 1, 1, 0}); + + String spec = "{\"ids\": true, \"dummycode\": [1]}"; + runTransformTest(fb, spec, expected); + + } + + @Test + public void testHash1() throws Exception { + FrameBlock fb = TestUtils.generateRandomFrameBlock(5, new ValueType[] {ValueType.UINT8, ValueType.INT64}, 32); + MatrixBlock expected = new MatrixBlock(1, 4, new double[] {1, 1, 1, 0}); + + String spec = "{\"ids\": true, \"dummycode\": [1], \"hash\": [1], \"K\": 3}"; + runTransformTest(fb, spec, expected); + + } + + @Test + public void testHash2() throws Exception { + FrameBlock fb = TestUtils.generateRandomFrameBlock(100, new ValueType[] {ValueType.UINT8, ValueType.INT64}, 32); + MatrixBlock expected = new MatrixBlock(1, 4, new double[] {1, 1, 1, 0}); + + String spec = "{\"ids\": true, \"dummycode\": [1], \"hash\": [1], \"K\": 3}"; + runTransformTest(fb, spec, expected); + + } + + @Test + public void testHash3() throws Exception { + FrameBlock fb = TestUtils.generateRandomFrameBlock(100, new ValueType[] {ValueType.UINT8, ValueType.INT64,ValueType.UINT8}, 32); + MatrixBlock expected = new MatrixBlock(1, 7, new double[] {1, 1, 1, 0, 1, 1, 1}); + + String spec = "{\"ids\": true, \"dummycode\": [1,3], \"hash\": [1,3], \"K\": 3}"; + runTransformTest(fb, spec, expected); + + } + + + @Test + public void testHybrid1() throws Exception { + FrameBlock fb = TestUtils.generateRandomFrameBlock(100, new ValueType[] {ValueType.UINT8, ValueType.INT64,ValueType.UINT8, ValueType.BOOLEAN}, 32); + MatrixBlock expected = new MatrixBlock(1, 9, new double[] {1, 1, 1, 0, 1, 1, 1,1,1}); + + String spec = "{\"ids\": true, \"dummycode\": [1,3,4], \"hash\": [1,3], \"K\": 3}"; + runTransformTest(fb, spec, expected); + + } + + @Test + public void testHybrid2() throws Exception { + FrameBlock fb = TestUtils.generateRandomFrameBlock(100, new ValueType[] {ValueType.UINT8, ValueType.BOOLEAN,ValueType.UINT8, ValueType.BOOLEAN}, 32); + MatrixBlock expected = new MatrixBlock(1, 10, new double[] {1, 1, 1, 1,1, 1, 1, 1,1,1}); + + String spec = "{\"ids\": true, \"dummycode\": [1,2,3,4], \"hash\": [1,3], \"K\": 3}"; + runTransformTest(fb, spec, expected); + + } + + private void runTransformTest(FrameBlock fb, String spec, MatrixBlock expected) throws Exception { + try { + + getAndLoadTestConfiguration(TEST_NAME1); + + String inF = input("F-In"); + String inS = input("spec"); + + TestUtils.writeTestFrame(inF, fb, fb.getSchema(), FileFormat.CSV); + TestUtils.writeTestScalar(input("spec"), spec); + + String out = output("ret"); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME1 + ".dml"; + programArgs = new String[] {"-args", inF, inS, out, expected.getNumColumns() + ""}; + + runTest(true, false, null, -1); + + MatrixBlock result = TestUtils.readBinary(out); + + TestUtils.compareMatrices(expected, result, 0.0); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + +} diff --git a/src/test/scripts/functions/transform/GetCategoricalMaskTest.dml b/src/test/scripts/functions/transform/GetCategoricalMaskTest.dml new file mode 100644 index 00000000000..5d7bb35a250 --- /dev/null +++ b/src/test/scripts/functions/transform/GetCategoricalMaskTest.dml @@ -0,0 +1,37 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +F1 = read($1, data_type="frame", format="csv"); + +jspec = read($2, data_type="scalar", value_type="string"); + +[X, M] = transformencode(target=F1, spec=jspec); + +Cm = getCategoricalMask(M, jspec) +expectedColumns = $4 +if(ncol(Cm) != expectedColumns){ + stop("Wrong number of metadata columns in categorical mask") +} +# print mean to verify that Cm is a matrix, not a Frame according to compiler +print(mean(Cm)) + +write(Cm, $3, format="csv"); +