From a2165d73f45b14cc30b71f7f41ec2364493393e5 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Fri, 29 May 2026 13:52:45 +0000 Subject: [PATCH 1/5] Add getCategoricalMask DML builtin Adds a new builtin that, given a transform-encode metadata frame and the encoding JSON spec, returns a 1xN matrix mask marking which output columns are categorical (1) versus continuous (0). Useful when callers need to know the category boundary in transformed output without re-deriving it from the spec. - Register GET_CATEGORICAL_MASK in Builtins, Opcodes, Types (OpOp2), Builtin (functionobject) - Validate it as a frame+scalar binary in BuiltinFunctionExpression (new checkFrameParam helper) and lower it to a BinaryOp in DMLTranslator - Force CP execution for the new op in BinaryOp.optFindExecType - Implement runtime in BinaryFrameScalarCPInstruction and route FRAME+SCALAR binary instructions to it in BinaryCPInstruction - Add writeTestScalar(String, String) overload to TestUtils - Cover recode, dummycode, hash, and hybrid specs in GetCategoricalMaskTest (note: hash variants depend on the decoder/encoder hash-column changes in a separate branch) --- .../org/apache/sysds/common/Builtins.java | 1 + .../java/org/apache/sysds/common/Opcodes.java | 2 + .../java/org/apache/sysds/common/Types.java | 1 + .../java/org/apache/sysds/hops/BinaryOp.java | 5 +- .../parser/BuiltinFunctionExpression.java | 16 ++ .../apache/sysds/parser/DMLTranslator.java | 3 + .../runtime/functionobjects/Builtin.java | 3 +- .../instructions/cp/BinaryCPInstruction.java | 2 + .../cp/BinaryFrameScalarCPInstruction.java | 130 ++++++++++++++ .../java/org/apache/sysds/test/TestUtils.java | 20 +++ .../transform/GetCategoricalMaskTest.java | 167 ++++++++++++++++++ .../transform/GetCategoricalMaskTest.dml | 37 ++++ 12 files changed, 385 insertions(+), 2 deletions(-) create mode 100644 src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameScalarCPInstruction.java create mode 100644 src/test/java/org/apache/sysds/test/functions/transform/GetCategoricalMaskTest.java create mode 100644 src/test/scripts/functions/transform/GetCategoricalMaskTest.dml diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java index e21c539d6d8..f5719641df7 100644 --- a/src/main/java/org/apache/sysds/common/Builtins.java +++ b/src/main/java/org/apache/sysds/common/Builtins.java @@ -154,6 +154,7 @@ public enum Builtins { GARCH("garch", true), GAUSSIAN_CLASSIFIER("gaussianClassifier", true), GET_ACCURACY("getAccuracy", true), + GET_CATEGORICAL_MASK("getCategoricalMask", false), GLM("glm", true), GLM_PREDICT("glmPredict", true), GLOVE("glove", true), diff --git a/src/main/java/org/apache/sysds/common/Opcodes.java b/src/main/java/org/apache/sysds/common/Opcodes.java index 1b0536416d6..9a894dde13b 100644 --- a/src/main/java/org/apache/sysds/common/Opcodes.java +++ b/src/main/java/org/apache/sysds/common/Opcodes.java @@ -215,6 +215,8 @@ public enum Opcodes { TRANSFORMMETA("transformmeta", InstructionType.ParameterizedBuiltin), TRANSFORMENCODE("transformencode", InstructionType.MultiReturnParameterizedBuiltin, InstructionType.MultiReturnBuiltin), + GET_CATEGORICAL_MASK("get_categorical_mask", InstructionType.Binary), + //Ternary instruction opcodes PM("+*", InstructionType.Ternary), MINUSMULT("-*", InstructionType.Ternary), diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java index 2e3543882d2..c2832aeb8cd 100644 --- a/src/main/java/org/apache/sysds/common/Types.java +++ b/src/main/java/org/apache/sysds/common/Types.java @@ -639,6 +639,7 @@ public enum OpOp2 { MINUS_NZ(false), //sparse-safe minus: X-(mean*ppred(X,0,!=)) LOG_NZ(false), //sparse-safe log; ppred(X,0,"!=")*log(X,0.5) MINUS1_MULT(false), //1-X*Y + GET_CATEGORICAL_MASK(false), // get transformation mask QUANTIZE_COMPRESS(false), //quantization-fused compression UNION_DISTINCT(false); diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java b/src/main/java/org/apache/sysds/hops/BinaryOp.java index 2b803a053c1..dc7edf76e50 100644 --- a/src/main/java/org/apache/sysds/hops/BinaryOp.java +++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java @@ -853,7 +853,10 @@ else if( (op == OpOp2.CBIND && getDataType().isList()) || (op == OpOp2.RBIND && getDataType().isList())) { _etype = ExecType.CP; } - + + if( op == OpOp2.GET_CATEGORICAL_MASK) + _etype = ExecType.CP; + //mark for recompile (forever) setRequiresRecompileIfNecessary(); diff --git a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java index 28f6949f722..ab0c7993b4e 100644 --- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java @@ -2018,6 +2018,15 @@ else if(this.getOpCode() == Builtins.MAX_POOL || this.getOpCode() == Builtins.AV else raiseValidateError("The compress or decompress instruction is not allowed in dml scripts"); break; + case GET_CATEGORICAL_MASK: + checkNumParameters(2); + checkFrameParam(getFirstExpr()); + checkScalarParam(getSecondExpr()); + output.setDataType(DataType.MATRIX); + output.setDimensions(1, -1); + output.setBlocksize( id.getBlocksize()); + output.setValueType(ValueType.FP64); + break; case QUANTIZE_COMPRESS: if(OptimizerUtils.ALLOW_SCRIPT_LEVEL_QUANTIZE_COMPRESS_COMMAND) { checkNumParameters(2); @@ -2383,6 +2392,13 @@ protected void checkMatrixFrameParam(Expression e) { //always unconditional raiseValidateError("Expecting matrix or frame parameter for function "+ getOpCode(), false, LanguageErrorCodes.UNSUPPORTED_PARAMETERS); } } + + protected void checkFrameParam(Expression e) { + if(e.getOutput().getDataType() != DataType.FRAME) { + raiseValidateError("Expecting frame parameter for function " + getOpCode(), false, + LanguageErrorCodes.UNSUPPORTED_PARAMETERS); + } + } protected void checkMatrixScalarParam(Expression e) { //always unconditional if (e.getOutput().getDataType() != DataType.MATRIX && e.getOutput().getDataType() != DataType.SCALAR) { diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java b/src/main/java/org/apache/sysds/parser/DMLTranslator.java index c6e7188d7bc..e14cfd31388 100644 --- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java +++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java @@ -2821,6 +2821,9 @@ else if ( in.length == 2 ) DataType.MATRIX, target.getValueType(), AggOp.COUNT_DISTINCT, Direction.Col, expr); break; + case GET_CATEGORICAL_MASK: + currBuiltinOp = new BinaryOp(target.getName(), DataType.MATRIX, ValueType.FP64, OpOp2.GET_CATEGORICAL_MASK, expr, expr2); + break; default: throw new ParseException("Unsupported builtin function type: "+source.getOpCode()); } diff --git a/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java b/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java index 39735be62e0..eed2c58f78c 100644 --- a/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java +++ b/src/main/java/org/apache/sysds/runtime/functionobjects/Builtin.java @@ -54,7 +54,7 @@ public enum BuiltinCode { AUTODIFF, SIN, COS, TAN, SINH, COSH, TANH, ASIN, ACOS, MAX, ABS, SIGN, SQRT, EXP, PLOGP, PRINT, PRINTF, NROW, NCOL, LENGTH, LINEAGE, ROUND, MAXINDEX, MININDEX, STOP, CEIL, FLOOR, CUMSUM, ROWCUMSUM, CUMPROD, CUMMIN, CUMMAX, CUMSUMPROD, INVERSE, SPROP, SIGMOID, EVAL, LIST, TYPEOF, APPLY_SCHEMA, DETECTSCHEMA, ISNA, ISNAN, ISINF, DROP_INVALID_TYPE, - DROP_INVALID_LENGTH, VALUE_SWAP, FRAME_ROW_REPLICATE, + DROP_INVALID_LENGTH, VALUE_SWAP, FRAME_ROW_REPLICATE, GET_CATEGORICAL_MASK, MAP, COUNT_DISTINCT, COUNT_DISTINCT_APPROX, UNIQUE} private static final VectorSpecies SPECIES = DoubleVector.SPECIES_PREFERRED; @@ -120,6 +120,7 @@ public enum BuiltinCode { AUTODIFF, SIN, COS, TAN, SINH, COSH, TANH, ASIN, ACOS, String2BuiltinCode.put( "_map", BuiltinCode.MAP); String2BuiltinCode.put( "valueSwap", BuiltinCode.VALUE_SWAP); String2BuiltinCode.put( "applySchema", BuiltinCode.APPLY_SCHEMA); + String2BuiltinCode.put( "get_categorical_mask", BuiltinCode.GET_CATEGORICAL_MASK); } protected Builtin(BuiltinCode bf) { diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java index 28b8775ebd5..86184f47be6 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryCPInstruction.java @@ -59,6 +59,8 @@ else if (in1.getDataType() == DataType.TENSOR && in2.getDataType() == DataType.T return new BinaryTensorTensorCPInstruction(operator, in1, in2, out, opcode, str); else if (in1.getDataType() == DataType.FRAME && in2.getDataType() == DataType.FRAME) return new BinaryFrameFrameCPInstruction(operator, in1, in2, out, opcode, str); + else if (in1.getDataType() == DataType.FRAME && in2.getDataType() == DataType.SCALAR) + return new BinaryFrameScalarCPInstruction(operator, in1, in2, out, opcode, str); else if (in1.getDataType() == DataType.FRAME && in2.getDataType() == DataType.MATRIX) return new BinaryFrameMatrixCPInstruction(operator, in1, in2, out, opcode, str); else diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameScalarCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameScalarCPInstruction.java new file mode 100644 index 00000000000..99b3c1a3b13 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameScalarCPInstruction.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.instructions.cp; + +import java.util.Arrays; + +import org.apache.sysds.common.Builtins; +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.columns.ColumnMetadata; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.operators.MultiThreadedOperator; +import org.apache.sysds.runtime.transform.TfUtils.TfMethod; +import org.apache.sysds.runtime.util.UtilFunctions; +import org.apache.wink.json4j.JSONArray; +import org.apache.wink.json4j.JSONObject; + +public class BinaryFrameScalarCPInstruction extends BinaryCPInstruction { + // private static final Log LOG = LogFactory.getLog(BinaryFrameFrameCPInstruction.class.getName()); + + protected BinaryFrameScalarCPInstruction(MultiThreadedOperator op, CPOperand in1, CPOperand in2, CPOperand out, + String opcode, String istr) { + super(CPType.Binary, op, in1, in2, out, opcode, istr); + } + + @Override + public void processInstruction(ExecutionContext ec) { + // get input frames + FrameBlock inBlock1 = ec.getFrameInput(input1.getName()); + ScalarObject spec = ec.getScalarInput(input2.getName(), ValueType.STRING, true); + if(getOpcode().equals(Builtins.GET_CATEGORICAL_MASK.toString().toLowerCase())) { + processGetCategorical(ec, inBlock1, spec); + } + else { + throw new DMLRuntimeException("Unsupported operation"); + } + + // Release the memory occupied by input frames + ec.releaseFrameInput(input1.getName()); + } + + public void processGetCategorical(ExecutionContext ec, FrameBlock f, ScalarObject spec) { + try { + + // MatrixBlock ret = new MatrixBlock(); + int nCol = f.getNumColumns(); + + JSONObject jSpec = new JSONObject(spec.getStringValue()); + + if(!jSpec.containsKey("ids") && jSpec.getBoolean("ids")) { + throw new DMLRuntimeException("not supported non ID based spec for get_categorical_mask"); + } + + String recode = TfMethod.RECODE.toString(); + String dummycode = TfMethod.DUMMYCODE.toString(); + + int[] lengths = new int[nCol]; + // assume all columns encode to at least one column. + Arrays.fill(lengths, 1); + boolean[] categorical = new boolean[nCol]; + + if(jSpec.containsKey(recode)) { + JSONArray a = jSpec.getJSONArray(recode); + for(Object aa : a) { + int av = (Integer) aa - 1; + categorical[av] = true; + } + } + + if(jSpec.containsKey(dummycode)) { + JSONArray a = jSpec.getJSONArray(dummycode); + for(Object aa : a) { + int av = (Integer) aa - 1; + ColumnMetadata d = f.getColumnMetadata()[av]; + String v = f.getString(0, av); + int ndist; + if(v.length() > 1 && v.charAt(0) == '¿') { + ndist = UtilFunctions.parseToInt(v.substring(1)); + } + else { + ndist = d.isDefault() ? 0 : (int) d.getNumDistinct(); + } + lengths[av] = ndist; + categorical[av] = true; + } + } + + // get total size after mapping + + int sumLengths = 0; + for(int i : lengths) { + sumLengths += i; + } + + MatrixBlock ret = new MatrixBlock(1, sumLengths, false); + ret.allocateDenseBlock(); + int off = 0; + for(int i = 0; i < lengths.length; i++) { + for(int j = 0; j < lengths[i]; j++) { + ret.set(0, off++, categorical[i] ? 1 : 0); + } + } + + ec.setMatrixOutput(output.getName(), ret); + + } + catch(Exception e) { + throw new DMLRuntimeException(e); + } + } +} diff --git a/src/test/java/org/apache/sysds/test/TestUtils.java b/src/test/java/org/apache/sysds/test/TestUtils.java index 5ebc243dd44..8cc82ec82b0 100644 --- a/src/test/java/org/apache/sysds/test/TestUtils.java +++ b/src/test/java/org/apache/sysds/test/TestUtils.java @@ -32,6 +32,7 @@ import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.FileReader; +import java.io.FileWriter; import java.io.IOException; import java.io.InputStreamReader; import java.io.OutputStreamWriter; @@ -2941,6 +2942,25 @@ public static void writeTestScalar(String file, double value) { } } + + /** + * Write scalar to file + * + * @param file File to write to + * @param value Value to write + */ + public static void writeTestScalar(String file, String value) { + try { + DataOutputStream out = new DataOutputStream(new FileOutputStream(file)); + try(PrintWriter pw = new PrintWriter(out)) { + pw.println(value); + } + } + catch(IOException e) { + fail("unable to write test scalar (" + file + "): " + e.getMessage()); + } + } + /** * Write scalar to file * diff --git a/src/test/java/org/apache/sysds/test/functions/transform/GetCategoricalMaskTest.java b/src/test/java/org/apache/sysds/test/functions/transform/GetCategoricalMaskTest.java new file mode 100644 index 00000000000..30681f373e4 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/functions/transform/GetCategoricalMaskTest.java @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.functions.transform; + +import static org.junit.Assert.fail; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.common.Types.FileFormat; +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; + +public class GetCategoricalMaskTest extends AutomatedTestBase { + protected static final Log LOG = LogFactory.getLog(GetCategoricalMaskTest.class.getName()); + + private final static String TEST_NAME1 = "GetCategoricalMaskTest"; + private final static String TEST_DIR = "functions/transform/"; + private final static String TEST_CLASS_DIR = TEST_DIR + TransformFrameEncodeApplyTest.class.getSimpleName() + "/"; + + @Override + public void setUp() { + TestUtils.clearAssertionInformation(); + addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"y"})); + } + + @Test + public void testRecode() throws Exception { + FrameBlock fb = TestUtils.generateRandomFrameBlock(10, new ValueType[] {ValueType.UINT8}, 32); + MatrixBlock expected = new MatrixBlock(1, 1, 1.0); + String spec = "{\"ids\": true, \"recode\": [1]}"; + runTransformTest(fb, spec, expected); + + } + + @Test + public void testRecode2() throws Exception { + FrameBlock fb = TestUtils.generateRandomFrameBlock(10, new ValueType[] {ValueType.UINT8, ValueType.UINT8}, 32); + MatrixBlock expected = new MatrixBlock(1, 2, new double[] {0, 1}); + + String spec = "{\"ids\": true, \"recode\": [2]}"; + runTransformTest(fb, spec, expected); + + } + + @Test + public void testDummy1() throws Exception { + FrameBlock fb = TestUtils.generateRandomFrameBlock(5, new ValueType[] {ValueType.UINT8, ValueType.INT64}, 32); + MatrixBlock expected = new MatrixBlock(1, 6, new double[] {0, 1, 1, 1, 1, 1}); + + String spec = "{\"ids\": true, \"dummycode\": [2]}"; + runTransformTest(fb, spec, expected); + + } + + @Test + public void testDummy2() throws Exception { + FrameBlock fb = TestUtils.generateRandomFrameBlock(5, new ValueType[] {ValueType.UINT8, ValueType.INT64}, 32); + MatrixBlock expected = new MatrixBlock(1, 6, new double[] {1, 1, 1, 1, 1, 0}); + + String spec = "{\"ids\": true, \"dummycode\": [1]}"; + runTransformTest(fb, spec, expected); + + } + + @Test + public void testHash1() throws Exception { + FrameBlock fb = TestUtils.generateRandomFrameBlock(5, new ValueType[] {ValueType.UINT8, ValueType.INT64}, 32); + MatrixBlock expected = new MatrixBlock(1, 4, new double[] {1, 1, 1, 0}); + + String spec = "{\"ids\": true, \"dummycode\": [1], \"hash\": [1], \"K\": 3}"; + runTransformTest(fb, spec, expected); + + } + + @Test + public void testHash2() throws Exception { + FrameBlock fb = TestUtils.generateRandomFrameBlock(100, new ValueType[] {ValueType.UINT8, ValueType.INT64}, 32); + MatrixBlock expected = new MatrixBlock(1, 4, new double[] {1, 1, 1, 0}); + + String spec = "{\"ids\": true, \"dummycode\": [1], \"hash\": [1], \"K\": 3}"; + runTransformTest(fb, spec, expected); + + } + + @Test + public void testHash3() throws Exception { + FrameBlock fb = TestUtils.generateRandomFrameBlock(100, new ValueType[] {ValueType.UINT8, ValueType.INT64,ValueType.UINT8}, 32); + MatrixBlock expected = new MatrixBlock(1, 7, new double[] {1, 1, 1, 0, 1, 1, 1}); + + String spec = "{\"ids\": true, \"dummycode\": [1,3], \"hash\": [1,3], \"K\": 3}"; + runTransformTest(fb, spec, expected); + + } + + + @Test + public void testHybrid1() throws Exception { + FrameBlock fb = TestUtils.generateRandomFrameBlock(100, new ValueType[] {ValueType.UINT8, ValueType.INT64,ValueType.UINT8, ValueType.BOOLEAN}, 32); + MatrixBlock expected = new MatrixBlock(1, 9, new double[] {1, 1, 1, 0, 1, 1, 1,1,1}); + + String spec = "{\"ids\": true, \"dummycode\": [1,3,4], \"hash\": [1,3], \"K\": 3}"; + runTransformTest(fb, spec, expected); + + } + + @Test + public void testHybrid2() throws Exception { + FrameBlock fb = TestUtils.generateRandomFrameBlock(100, new ValueType[] {ValueType.UINT8, ValueType.BOOLEAN,ValueType.UINT8, ValueType.BOOLEAN}, 32); + MatrixBlock expected = new MatrixBlock(1, 10, new double[] {1, 1, 1, 1,1, 1, 1, 1,1,1}); + + String spec = "{\"ids\": true, \"dummycode\": [1,2,3,4], \"hash\": [1,3], \"K\": 3}"; + runTransformTest(fb, spec, expected); + + } + + private void runTransformTest(FrameBlock fb, String spec, MatrixBlock expected) throws Exception { + try { + + getAndLoadTestConfiguration(TEST_NAME1); + + String inF = input("F-In"); + String inS = input("spec"); + + TestUtils.writeTestFrame(inF, fb, fb.getSchema(), FileFormat.CSV); + TestUtils.writeTestScalar(input("spec"), spec); + + String out = output("ret"); + + String HOME = SCRIPT_DIR + TEST_DIR; + fullDMLScriptName = HOME + TEST_NAME1 + ".dml"; + programArgs = new String[] {"-args", inF, inS, out, expected.getNumColumns() + ""}; + + runTest(true, false, null, -1); + + MatrixBlock result = TestUtils.readBinary(out); + + TestUtils.compareMatrices(expected, result, 0.0); + } + catch(Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + +} diff --git a/src/test/scripts/functions/transform/GetCategoricalMaskTest.dml b/src/test/scripts/functions/transform/GetCategoricalMaskTest.dml new file mode 100644 index 00000000000..5d7bb35a250 --- /dev/null +++ b/src/test/scripts/functions/transform/GetCategoricalMaskTest.dml @@ -0,0 +1,37 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +F1 = read($1, data_type="frame", format="csv"); + +jspec = read($2, data_type="scalar", value_type="string"); + +[X, M] = transformencode(target=F1, spec=jspec); + +Cm = getCategoricalMask(M, jspec) +expectedColumns = $4 +if(ncol(Cm) != expectedColumns){ + stop("Wrong number of metadata columns in categorical mask") +} +# print mean to verify that Cm is a matrix, not a Frame according to compiler +print(mean(Cm)) + +write(Cm, $3, format="csv"); + From c528fa2f3a31c04b0a54dc9df2da2d67c95c4f89 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Mon, 8 Jun 2026 15:25:53 +0000 Subject: [PATCH 2/5] Handle feature-hash columns in getCategoricalMask and fix unused import getCategoricalMask only accounted for recode and dummycode columns, so specs using feature hashing produced a mask with the wrong number of columns and the DML check failed. Parse the hash column list and bucket count K from the spec: a hashed column is categorical, and a hashed column that is also dummycoded expands to K columns rather than the recode distinct count. Also correct the ID-based spec guard (it never triggered) to actually require an ID-based spec, and remove the now-unused java.io.FileWriter import in TestUtils that broke Checkstyle. --- .../cp/BinaryFrameScalarCPInstruction.java | 34 +++++++++++++++---- .../java/org/apache/sysds/test/TestUtils.java | 1 - 2 files changed, 28 insertions(+), 7 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameScalarCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameScalarCPInstruction.java index 99b3c1a3b13..bbf4774ed7a 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameScalarCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameScalarCPInstruction.java @@ -66,18 +66,34 @@ public void processGetCategorical(ExecutionContext ec, FrameBlock f, ScalarObjec JSONObject jSpec = new JSONObject(spec.getStringValue()); - if(!jSpec.containsKey("ids") && jSpec.getBoolean("ids")) { + if(!jSpec.containsKey("ids") || !jSpec.getBoolean("ids")) { throw new DMLRuntimeException("not supported non ID based spec for get_categorical_mask"); } String recode = TfMethod.RECODE.toString(); String dummycode = TfMethod.DUMMYCODE.toString(); + String hash = TfMethod.HASH.toString(); int[] lengths = new int[nCol]; // assume all columns encode to at least one column. Arrays.fill(lengths, 1); boolean[] categorical = new boolean[nCol]; + // feature-hashed columns map to K buckets; a plain hashed column + // produces a single (categorical) bucket-id column, while a hashed + // column that is additionally dummycoded expands to K columns. + boolean[] hashed = new boolean[nCol]; + int K = 0; + if(jSpec.containsKey(hash)) { + K = jSpec.getInt("K"); + JSONArray a = jSpec.getJSONArray(hash); + for(Object aa : a) { + int av = (Integer) aa - 1; + hashed[av] = true; + categorical[av] = true; + } + } + if(jSpec.containsKey(recode)) { JSONArray a = jSpec.getJSONArray(recode); for(Object aa : a) { @@ -90,14 +106,20 @@ public void processGetCategorical(ExecutionContext ec, FrameBlock f, ScalarObjec JSONArray a = jSpec.getJSONArray(dummycode); for(Object aa : a) { int av = (Integer) aa - 1; - ColumnMetadata d = f.getColumnMetadata()[av]; - String v = f.getString(0, av); int ndist; - if(v.length() > 1 && v.charAt(0) == '¿') { - ndist = UtilFunctions.parseToInt(v.substring(1)); + if(hashed[av]) { + // feature hashing followed by dummycoding yields K columns + ndist = K; } else { - ndist = d.isDefault() ? 0 : (int) d.getNumDistinct(); + ColumnMetadata d = f.getColumnMetadata()[av]; + String v = f.getString(0, av); + if(v.length() > 1 && v.charAt(0) == '¿') { + ndist = UtilFunctions.parseToInt(v.substring(1)); + } + else { + ndist = d.isDefault() ? 0 : (int) d.getNumDistinct(); + } } lengths[av] = ndist; categorical[av] = true; diff --git a/src/test/java/org/apache/sysds/test/TestUtils.java b/src/test/java/org/apache/sysds/test/TestUtils.java index 8cc82ec82b0..86db894f8e3 100644 --- a/src/test/java/org/apache/sysds/test/TestUtils.java +++ b/src/test/java/org/apache/sysds/test/TestUtils.java @@ -32,7 +32,6 @@ import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.FileReader; -import java.io.FileWriter; import java.io.IOException; import java.io.InputStreamReader; import java.io.OutputStreamWriter; From 6aef55a5efc5e517e2634f3f789b96bb68a62b94 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Tue, 9 Jun 2026 13:39:50 +0000 Subject: [PATCH 3/5] Add unit tests for getCategoricalMask defensive code paths Drive the get_categorical_mask instruction directly to cover branches the script-level transform test cannot reach: the inline distinct-count prefix in metadata cells, default column metadata yielding zero columns, non id-based spec rejection, and the unsupported opcode guard. The error cases assert the specific exception message so they verify the intended failure rather than any wrapped exception. --- .../GetCategoricalMaskInstructionTest.java | 148 ++++++++++++++++++ 1 file changed, 148 insertions(+) create mode 100644 src/test/java/org/apache/sysds/test/component/frame/transform/GetCategoricalMaskInstructionTest.java diff --git a/src/test/java/org/apache/sysds/test/component/frame/transform/GetCategoricalMaskInstructionTest.java b/src/test/java/org/apache/sysds/test/component/frame/transform/GetCategoricalMaskInstructionTest.java new file mode 100644 index 00000000000..b8df642542f --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/frame/transform/GetCategoricalMaskInstructionTest.java @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.frame.transform; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.common.Types.DataType; +import org.apache.sysds.common.Types.FileFormat; +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.caching.CacheableData; +import org.apache.sysds.runtime.controlprogram.caching.FrameObject; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContextFactory; +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.instructions.InstructionUtils; +import org.apache.sysds.runtime.instructions.cp.BinaryCPInstruction; +import org.apache.sysds.runtime.instructions.cp.BinaryFrameScalarCPInstruction; +import org.apache.sysds.runtime.instructions.cp.StringObject; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.meta.MatrixCharacteristics; +import org.apache.sysds.runtime.meta.MetaDataFormat; +import org.junit.BeforeClass; +import org.junit.Test; + +/** + * Unit tests that drive the get_categorical_mask instruction directly to exercise the defensive code + * paths (distinct-count prefix in the metadata frame, default column metadata, non id-based specs and + * the unsupported opcode guard) that the script-level transform tests cannot reach. + */ +public class GetCategoricalMaskInstructionTest { + protected static final Log LOG = LogFactory.getLog(GetCategoricalMaskInstructionTest.class.getName()); + + private static final String MASK_OPCODE = "get_categorical_mask"; + + @BeforeClass + public static void init() throws java.io.IOException { + CacheableData.initCaching("get_categorical_mask_instruction_test"); + } + + @Test + public void dummycodeReadsDistinctCountFromMetadataPrefix() { + // a metadata cell prefixed with '¿' encodes the number of distinct values inline + FrameBlock meta = new FrameBlock(new ValueType[] {ValueType.STRING}, new String[][] {{"¿3"}}); + MatrixBlock res = run(meta, "{\"ids\": true, \"dummycode\": [1]}"); + + assertEquals(1, res.getNumRows()); + assertEquals(3, res.getNumColumns()); + assertArrayEquals(new double[] {1, 1, 1}, res.getDenseBlockValues(), 0.0); + } + + @Test + public void dummycodeDefaultMetadataContributesNoColumns() { + // first column is dummycoded but carries default metadata (no distinct count) -> 0 columns, + // the trailing pass-through column keeps the output non-empty + FrameBlock meta = new FrameBlock(new ValueType[] {ValueType.STRING, ValueType.STRING}, + new String[][] {{"x", "y"}}); + MatrixBlock res = run(meta, "{\"ids\": true, \"dummycode\": [1]}"); + + assertEquals(1, res.getNumRows()); + assertEquals(1, res.getNumColumns()); + assertEquals(0.0, res.get(0, 0), 0.0); + } + + @Test + public void nonIdSpecMissingIdsKeyThrows() { + // a spec without the "ids" key must be rejected, not silently mis-interpreted + FrameBlock meta = new FrameBlock(new ValueType[] {ValueType.STRING}, new String[][] {{"a"}}); + assertThrowsMessage("non ID based spec", () -> run(meta, "{\"recode\": [1]}")); + } + + @Test + public void nonIdSpecIdsFalseThrows() { + // "ids": false is equally unsupported + FrameBlock meta = new FrameBlock(new ValueType[] {ValueType.STRING}, new String[][] {{"a"}}); + assertThrowsMessage("non ID based spec", () -> run(meta, "{\"ids\": false, \"recode\": [1]}")); + } + + @Test + public void unsupportedOpcodeThrows() { + // any frame-scalar binary opcode other than get_categorical_mask must be rejected + ExecutionContext ec = ExecutionContextFactory.createContext(); + ec.setAutoCreateVars(true); + ec.setVariable("F", frameObject(new FrameBlock(new ValueType[] {ValueType.STRING}, + new String[][] {{"a"}}))); + assertThrowsMessage("Unsupported operation", () -> maskInstruction("+").processInstruction(ec)); + } + + /** Assert the action throws a DMLRuntimeException whose message chain contains the expected text. */ + private static void assertThrowsMessage(String expected, Runnable action) { + try { + action.run(); + fail("Expected DMLRuntimeException containing \"" + expected + "\" but nothing was thrown"); + } + catch(DMLRuntimeException e) { + StringBuilder chain = new StringBuilder(); + for(Throwable t = e; t != null; t = t.getCause()) + chain.append(t.getMessage()).append(" | "); + assertTrue("Exception chain [" + chain + "] should contain \"" + expected + "\"", + chain.toString().contains(expected)); + } + } + + private static MatrixBlock run(FrameBlock meta, String spec) { + ExecutionContext ec = ExecutionContextFactory.createContext(); + ec.setAutoCreateVars(true); + maskInstruction(MASK_OPCODE).processGetCategorical(ec, meta, new StringObject(spec)); + return ec.getMatrixObject("out").acquireReadAndRelease(); + } + + private static BinaryFrameScalarCPInstruction maskInstruction(String opcode) { + String in1 = InstructionUtils.concatOperandParts("F", DataType.FRAME.name(), ValueType.STRING.name(), "false"); + String in2 = InstructionUtils.concatOperandParts("spec", DataType.SCALAR.name(), ValueType.STRING.name(), "true"); + String out = InstructionUtils.concatOperandParts("out", DataType.MATRIX.name(), ValueType.FP64.name(), "false"); + String str = InstructionUtils.concatOperands("CP", opcode, in1, in2, out); + return (BinaryFrameScalarCPInstruction) BinaryCPInstruction.parseInstruction(str); + } + + private static FrameObject frameObject(FrameBlock fb) { + MatrixCharacteristics mc = new MatrixCharacteristics(fb.getNumRows(), fb.getNumColumns(), -1, -1); + FrameObject fo = new FrameObject("F", new MetaDataFormat(mc, FileFormat.BINARY), fb.getSchema()); + fo.acquireModify(fb); + fo.release(); + return fo; + } +} From 46338ba4296ebb03640b90103a94edf74e6764a7 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Thu, 11 Jun 2026 16:51:36 +0000 Subject: [PATCH 4/5] Reject unsupported transform methods in getCategoricalMask getCategoricalMask only models recode/dummycode/hash column expansion. Specs using bin, word_embedding, bag_of_words, or udf would yield a mask with the wrong number of columns, so reject them explicitly with a clear error instead of returning a silently incorrect result. impute and omit remain accepted since they do not change output arity or the categorical flag. Add instruction-level tests covering each rejected method and the accepted impute/omit case. --- .../cp/BinaryFrameScalarCPInstruction.java | 13 ++++++ .../GetCategoricalMaskInstructionTest.java | 44 +++++++++++++++++++ 2 files changed, 57 insertions(+) diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameScalarCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameScalarCPInstruction.java index bbf4774ed7a..d3ca3dc0a31 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameScalarCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameScalarCPInstruction.java @@ -70,6 +70,19 @@ public void processGetCategorical(ExecutionContext ec, FrameBlock f, ScalarObjec throw new DMLRuntimeException("not supported non ID based spec for get_categorical_mask"); } + // get_categorical_mask only models the column expansion of recode/dummycode/hash. + // Methods that change the output arity (bin expands under dummycode, word_embedding and + // bag_of_words map to many columns) or are user-defined (udf) would produce a mask with + // the wrong number of columns, so reject them explicitly instead of emitting a silently + // incorrect result. impute and omit are intentionally allowed: they do not alter the + // output column count or the categorical flag of a column. + for(TfMethod m : new TfMethod[] {TfMethod.BIN, TfMethod.WORD_EMBEDDING, TfMethod.BAG_OF_WORDS, + TfMethod.UDF}) { + if(jSpec.containsKey(m.toString())) + throw new DMLRuntimeException( + "unsupported transform method '" + m + "' for get_categorical_mask"); + } + String recode = TfMethod.RECODE.toString(); String dummycode = TfMethod.DUMMYCODE.toString(); String hash = TfMethod.HASH.toString(); diff --git a/src/test/java/org/apache/sysds/test/component/frame/transform/GetCategoricalMaskInstructionTest.java b/src/test/java/org/apache/sysds/test/component/frame/transform/GetCategoricalMaskInstructionTest.java index b8df642542f..b8cc9151fd5 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/transform/GetCategoricalMaskInstructionTest.java +++ b/src/test/java/org/apache/sysds/test/component/frame/transform/GetCategoricalMaskInstructionTest.java @@ -98,6 +98,50 @@ public void nonIdSpecIdsFalseThrows() { assertThrowsMessage("non ID based spec", () -> run(meta, "{\"ids\": false, \"recode\": [1]}")); } + @Test + public void unsupportedBinMethodThrows() { + // bin expands to bin-count columns under dummycode, which the mask does not model + FrameBlock meta = new FrameBlock(new ValueType[] {ValueType.STRING}, new String[][] {{"a"}}); + assertThrowsMessage("unsupported transform method 'bin'", + () -> run(meta, "{\"ids\": true, \"bin\": [{\"id\": 1, \"method\": \"equi-width\", \"numbins\": 3}]}")); + } + + @Test + public void unsupportedWordEmbeddingMethodThrows() { + // word_embedding maps a column to an embedding vector (many columns), not a single mask entry + FrameBlock meta = new FrameBlock(new ValueType[] {ValueType.STRING}, new String[][] {{"a"}}); + assertThrowsMessage("unsupported transform method 'word_embedding'", + () -> run(meta, "{\"ids\": true, \"word_embedding\": [1]}")); + } + + @Test + public void unsupportedBagOfWordsMethodThrows() { + // bag_of_words expands to one column per dictionary token + FrameBlock meta = new FrameBlock(new ValueType[] {ValueType.STRING}, new String[][] {{"a"}}); + assertThrowsMessage("unsupported transform method 'bag_of_words'", + () -> run(meta, "{\"ids\": true, \"bag_of_words\": [1]}")); + } + + @Test + public void unsupportedUdfMethodThrows() { + // udf output arity is user-defined and cannot be inferred from the spec + FrameBlock meta = new FrameBlock(new ValueType[] {ValueType.STRING}, new String[][] {{"a"}}); + assertThrowsMessage("unsupported transform method 'udf'", + () -> run(meta, "{\"ids\": true, \"udf\": {\"name\": \"f\", \"ids\": [1]}}")); + } + + @Test + public void imputeAndOmitAreAccepted() { + // impute and omit do not change the output column count or categorical flag, so a spec that + // only adds them on top of a recoded column must still succeed and mark that column categorical + FrameBlock meta = new FrameBlock(new ValueType[] {ValueType.STRING}, new String[][] {{"a"}}); + MatrixBlock res = run(meta, "{\"ids\": true, \"recode\": [1], \"impute\": [{\"id\": 1, \"method\": \"global_mode\"}], \"omit\": [1]}"); + + assertEquals(1, res.getNumRows()); + assertEquals(1, res.getNumColumns()); + assertEquals(1.0, res.get(0, 0), 0.0); + } + @Test public void unsupportedOpcodeThrows() { // any frame-scalar binary opcode other than get_categorical_mask must be rejected From 75dd026e748c9cd11f9d84e783081f90ea50a635 Mon Sep 17 00:00:00 2001 From: Sebastian Baunsgaard Date: Thu, 11 Jun 2026 18:20:51 +0000 Subject: [PATCH 5/5] Refactor getCategoricalMask into staged builder and expand tests Extract the spec validation into a validate helper and move the per-column mask accumulation into a CategoricalMask inner class with one method per stage (hash, recode, dummycode, sumLengths, toMatrixBlock). Allocate the lengths/categorical/hashed arrays lazily so a column no method touches keeps the default of a single non-categorical output column, and short-circuit sumLengths to nCol when no expansion occurred. Add instruction-level tests covering output-column offset mapping across interleaved encodings: pass-through-only specs, recode/dummycode interleaved with continuous columns, varying dummycode widths, a column listed in both recode and dummycode, hash-only columns, and a mixed hash/dummycode/recode row. Also cover the JSONException wrapping path in validate. --- .../cp/BinaryFrameScalarCPInstruction.java | 205 +++++++++++------- .../GetCategoricalMaskInstructionTest.java | 114 ++++++++++ 2 files changed, 237 insertions(+), 82 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameScalarCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameScalarCPInstruction.java index d3ca3dc0a31..193894fd9bc 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameScalarCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/BinaryFrameScalarCPInstruction.java @@ -31,12 +31,15 @@ import org.apache.sysds.runtime.matrix.operators.MultiThreadedOperator; import org.apache.sysds.runtime.transform.TfUtils.TfMethod; import org.apache.sysds.runtime.util.UtilFunctions; -import org.apache.wink.json4j.JSONArray; +import org.apache.wink.json4j.JSONException; import org.apache.wink.json4j.JSONObject; public class BinaryFrameScalarCPInstruction extends BinaryCPInstruction { // private static final Log LOG = LogFactory.getLog(BinaryFrameFrameCPInstruction.class.getName()); + private static final TfMethod[] UNSUPPORTED_MASK_METHODS = new TfMethod[] {TfMethod.BIN, + TfMethod.WORD_EMBEDDING, TfMethod.BAG_OF_WORDS, TfMethod.UDF}; + protected BinaryFrameScalarCPInstruction(MultiThreadedOperator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) { super(CPType.Binary, op, in1, in2, out, opcode, istr); @@ -58,108 +61,146 @@ public void processInstruction(ExecutionContext ec) { ec.releaseFrameInput(input1.getName()); } - public void processGetCategorical(ExecutionContext ec, FrameBlock f, ScalarObject spec) { + private static void validate(JSONObject jSpec) { try { + if(!jSpec.containsKey("ids") || !jSpec.getBoolean("ids")) + throw new DMLRuntimeException("not supported non ID based spec for get_categorical_mask"); - // MatrixBlock ret = new MatrixBlock(); - int nCol = f.getNumColumns(); + for(TfMethod m : UNSUPPORTED_MASK_METHODS) + if(jSpec.containsKey(m.toString())) + throw new DMLRuntimeException("unsupported transform method '" + m + "' for get_categorical_mask"); + } + catch(JSONException e) { + throw new DMLRuntimeException(e); + } + } + public void processGetCategorical(ExecutionContext ec, FrameBlock f, ScalarObject spec) { + try { + // 1. extract the spec, 2. validate it JSONObject jSpec = new JSONObject(spec.getStringValue()); + validate(jSpec); - if(!jSpec.containsKey("ids") || !jSpec.getBoolean("ids")) { - throw new DMLRuntimeException("not supported non ID based spec for get_categorical_mask"); - } + // 3.-5. fold each supported transform method into the per-column mask state + CategoricalMask mask = new CategoricalMask(f, jSpec); + mask.hash(); + mask.recode(); + mask.dummycode(); - // get_categorical_mask only models the column expansion of recode/dummycode/hash. - // Methods that change the output arity (bin expands under dummycode, word_embedding and - // bag_of_words map to many columns) or are user-defined (udf) would produce a mask with - // the wrong number of columns, so reject them explicitly instead of emitting a silently - // incorrect result. impute and omit are intentionally allowed: they do not alter the - // output column count or the categorical flag of a column. - for(TfMethod m : new TfMethod[] {TfMethod.BIN, TfMethod.WORD_EMBEDDING, TfMethod.BAG_OF_WORDS, - TfMethod.UDF}) { - if(jSpec.containsKey(m.toString())) - throw new DMLRuntimeException( - "unsupported transform method '" + m + "' for get_categorical_mask"); - } + // 6.-7. size and materialize the output mask + ec.setMatrixOutput(output.getName(), mask.toMatrixBlock()); + } + catch(Exception e) { + throw new DMLRuntimeException(e); + } + } - String recode = TfMethod.RECODE.toString(); - String dummycode = TfMethod.DUMMYCODE.toString(); - String hash = TfMethod.HASH.toString(); + /** + * Accumulates, per input column, how many output columns it expands to (lengths) and whether those + * output columns are categorical (categorical). The arrays are allocated lazily: a column that no + * method touches keeps the implicit default of a single, non-categorical output column. + */ + private static final class CategoricalMask { + private final FrameBlock f; + private final JSONObject jSpec; + private final int nCol; + + private int[] lengths = null; + private boolean[] categorical = null; + + // feature-hashed columns map to K buckets; a plain hashed column produces a single + // (categorical) bucket-id column, while a hashed column that is additionally dummycoded + // expands to K columns. + private boolean[] hashed = null; + private int K = 0; + + private CategoricalMask(FrameBlock f, JSONObject jSpec) { + this.f = f; + this.jSpec = jSpec; + this.nCol = f.getNumColumns(); + } - int[] lengths = new int[nCol]; - // assume all columns encode to at least one column. - Arrays.fill(lengths, 1); - boolean[] categorical = new boolean[nCol]; - - // feature-hashed columns map to K buckets; a plain hashed column - // produces a single (categorical) bucket-id column, while a hashed - // column that is additionally dummycoded expands to K columns. - boolean[] hashed = new boolean[nCol]; - int K = 0; - if(jSpec.containsKey(hash)) { - K = jSpec.getInt("K"); - JSONArray a = jSpec.getJSONArray(hash); - for(Object aa : a) { - int av = (Integer) aa - 1; - hashed[av] = true; - categorical[av] = true; - } + private void hash() throws JSONException { + String hash = TfMethod.HASH.toString(); + if(!jSpec.containsKey(hash)) + return; + K = jSpec.getInt("K"); + hashed = new boolean[nCol]; + ensureCategorical(); + for(Object aa : jSpec.getJSONArray(hash)) { + int av = (Integer) aa - 1; + hashed[av] = true; + categorical[av] = true; } + } - if(jSpec.containsKey(recode)) { - JSONArray a = jSpec.getJSONArray(recode); - for(Object aa : a) { - int av = (Integer) aa - 1; - categorical[av] = true; - } + private void recode() throws JSONException { + String recode = TfMethod.RECODE.toString(); + if(!jSpec.containsKey(recode)) + return; + ensureCategorical(); + for(Object aa : jSpec.getJSONArray(recode)) { + int av = (Integer) aa - 1; + categorical[av] = true; } + } - if(jSpec.containsKey(dummycode)) { - JSONArray a = jSpec.getJSONArray(dummycode); - for(Object aa : a) { - int av = (Integer) aa - 1; - int ndist; - if(hashed[av]) { - // feature hashing followed by dummycoding yields K columns - ndist = K; - } - else { - ColumnMetadata d = f.getColumnMetadata()[av]; - String v = f.getString(0, av); - if(v.length() > 1 && v.charAt(0) == '¿') { - ndist = UtilFunctions.parseToInt(v.substring(1)); - } - else { - ndist = d.isDefault() ? 0 : (int) d.getNumDistinct(); - } - } - lengths[av] = ndist; - categorical[av] = true; - } + private void dummycode() throws JSONException { + String dummycode = TfMethod.DUMMYCODE.toString(); + if(!jSpec.containsKey(dummycode)) + return; + ensureCategorical(); + ensureLengths(); + for(Object aa : jSpec.getJSONArray(dummycode)) { + int av = (Integer) aa - 1; + lengths[av] = distinctCount(av); + categorical[av] = true; } + } - // get total size after mapping + private int distinctCount(int av) { + if(hashed != null && hashed[av]) + // feature hashing followed by dummycoding yields K columns + return K; + ColumnMetadata d = f.getColumnMetadata()[av]; + String v = f.getString(0, av); + if(v.length() > 1 && v.charAt(0) == '¿') + return UtilFunctions.parseToInt(v.substring(1)); + return d.isDefault() ? 0 : (int) d.getNumDistinct(); + } - int sumLengths = 0; - for(int i : lengths) { - sumLengths += i; - } + private int sumLengths() { + if(lengths == null) + return nCol; + int sum = 0; + for(int i = 0; i < nCol; i++) + sum += lengths[i]; + return sum; + } - MatrixBlock ret = new MatrixBlock(1, sumLengths, false); + private MatrixBlock toMatrixBlock() { + MatrixBlock ret = new MatrixBlock(1, sumLengths(), false); ret.allocateDenseBlock(); int off = 0; - for(int i = 0; i < lengths.length; i++) { - for(int j = 0; j < lengths[i]; j++) { - ret.set(0, off++, categorical[i] ? 1 : 0); - } + for(int i = 0; i < nCol; i++) { + int len = (lengths == null) ? 1 : lengths[i]; + double val = (categorical != null && categorical[i]) ? 1 : 0; + for(int j = 0; j < len; j++) + ret.set(0, off++, val); } + return ret; + } - ec.setMatrixOutput(output.getName(), ret); - + private void ensureCategorical() { + if(categorical == null) + categorical = new boolean[nCol]; } - catch(Exception e) { - throw new DMLRuntimeException(e); + + private void ensureLengths() { + if(lengths == null) { + lengths = new int[nCol]; + Arrays.fill(lengths, 1); + } } } } diff --git a/src/test/java/org/apache/sysds/test/component/frame/transform/GetCategoricalMaskInstructionTest.java b/src/test/java/org/apache/sysds/test/component/frame/transform/GetCategoricalMaskInstructionTest.java index b8cc9151fd5..d9c540f54c5 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/transform/GetCategoricalMaskInstructionTest.java +++ b/src/test/java/org/apache/sysds/test/component/frame/transform/GetCategoricalMaskInstructionTest.java @@ -35,6 +35,7 @@ import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.controlprogram.context.ExecutionContextFactory; import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.columns.ColumnMetadata; import org.apache.sysds.runtime.instructions.InstructionUtils; import org.apache.sysds.runtime.instructions.cp.BinaryCPInstruction; import org.apache.sysds.runtime.instructions.cp.BinaryFrameScalarCPInstruction; @@ -84,6 +85,80 @@ public void dummycodeDefaultMetadataContributesNoColumns() { assertEquals(0.0, res.get(0, 0), 0.0); } + @Test + public void noMethodAllColumnsPassThrough() { + // a spec with only "ids" touches no column: every column is a single, non-categorical output + FrameBlock meta = metaWithDistinct(3, new int[] {0, 0, 0}); + MatrixBlock res = run(meta, "{\"ids\": true}"); + + assertMask(res, new double[] {0, 0, 0}); + } + + @Test + public void recodeInterleavedWithPassThrough() { + // categorical (recode, 1 col each) interleaved with continuous pass-through columns + FrameBlock meta = metaWithDistinct(5, new int[] {0, 0, 0, 0, 0}); + MatrixBlock res = run(meta, "{\"ids\": true, \"recode\": [1, 4]}"); + + assertMask(res, new double[] {1, 0, 0, 1, 0}); + } + + @Test + public void leadingPassThroughThenDummycodeOffsets() { + // the dummycode expansion must start at the correct offset after three continuous columns + FrameBlock meta = metaWithDistinct(4, new int[] {0, 0, 0, 3}); + MatrixBlock res = run(meta, "{\"ids\": true, \"dummycode\": [4]}"); + + assertMask(res, new double[] {0, 0, 0, 1, 1, 1}); + } + + @Test + public void multipleDummycodeVaryingDistinctCounts() { + // several dummycoded columns of different widths, all categorical, no pass-through + FrameBlock meta = metaWithDistinct(3, new int[] {2, 4, 1}); + MatrixBlock res = run(meta, "{\"ids\": true, \"dummycode\": [1, 2, 3]}"); + + assertMask(res, new double[] {1, 1, 1, 1, 1, 1, 1}); + } + + @Test + public void dummycodeAndPassThroughAndRecodeInterleaved() { + // dummycode(3) | pass-through | recode | dummycode(2): exercises every offset transition + FrameBlock meta = metaWithDistinct(4, new int[] {3, 0, 0, 2}); + MatrixBlock res = run(meta, "{\"ids\": true, \"recode\": [3], \"dummycode\": [1, 4]}"); + + assertMask(res, new double[] {1, 1, 1, 0, 1, 1, 1}); + } + + @Test + public void recodeAndDummycodeOnSameColumnExpands() { + // a column listed in both recode and dummycode must expand to its dummycode width, not collapse + FrameBlock meta = metaWithDistinct(2, new int[] {4, 0}); + MatrixBlock res = run(meta, "{\"ids\": true, \"recode\": [1], \"dummycode\": [1]}"); + + assertMask(res, new double[] {1, 1, 1, 1, 0}); + } + + @Test + public void hashOnlyColumnStaysSingleCategorical() { + // a hashed-but-not-dummycoded column is a single categorical column; K must not widen it + FrameBlock meta = metaWithDistinct(3, new int[] {0, 0, 0}); + MatrixBlock res = run(meta, "{\"ids\": true, \"hash\": [2], \"K\": 5}"); + + assertMask(res, new double[] {0, 1, 0}); + } + + @Test + public void hashDummycodeRecodePassThroughMixed() { + // col1: hash+dummycode -> K=3 (metadata ignored); col2: pass-through; col3: dummycode(9); + // col4: pass-through; col5: recode. Verifies hashed columns use K while plain dummycode uses + // the metadata distinct count, with correct offsets across the whole row. + FrameBlock meta = metaWithDistinct(5, new int[] {0, 0, 9, 0, 0}); + MatrixBlock res = run(meta, "{\"ids\": true, \"recode\": [5], \"dummycode\": [1, 3], \"hash\": [1], \"K\": 3}"); + + assertMask(res, new double[] {1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1}); + } + @Test public void nonIdSpecMissingIdsKeyThrows() { // a spec without the "ids" key must be rejected, not silently mis-interpreted @@ -142,6 +217,14 @@ public void imputeAndOmitAreAccepted() { assertEquals(1.0, res.get(0, 0), 0.0); } + @Test + public void malformedSpecWrapsJsonException() { + // "ids" present but not a boolean makes spec parsing throw a JSONException, which must be + // wrapped as a DMLRuntimeException rather than propagating raw + FrameBlock meta = new FrameBlock(new ValueType[] {ValueType.STRING}, new String[][] {{"a"}}); + assertThrowsMessage("was not a boolean", () -> run(meta, "{\"ids\": 5, \"recode\": [1]}")); + } + @Test public void unsupportedOpcodeThrows() { // any frame-scalar binary opcode other than get_categorical_mask must be rejected @@ -167,6 +250,37 @@ private static void assertThrowsMessage(String expected, Runnable action) { } } + /** Assert the mask is a single row equal to the expected values (which also fixes its width). */ + private static void assertMask(MatrixBlock res, double[] expected) { + assertEquals(1, res.getNumRows()); + assertEquals(expected.length, res.getNumColumns()); + // compare per cell rather than via getDenseBlockValues(): an all-zero mask has nnz == 0 and + // therefore no materialized dense block + double[] actual = new double[expected.length]; + for(int i = 0; i < expected.length; i++) + actual[i] = res.get(0, i); + assertArrayEquals(expected, actual, 0.0); + } + + /** + * Build a single-row metadata frame of nCol string columns. A positive distinct[i] is written to + * that column's metadata as the recode distinct count (the path real transformencode uses), while + * a zero leaves the column with default metadata (a continuous / non-dummycoded column). + */ + private static FrameBlock metaWithDistinct(int nCol, int[] distinct) { + ValueType[] schema = new ValueType[nCol]; + String[][] data = new String[1][nCol]; + for(int i = 0; i < nCol; i++) { + schema[i] = ValueType.STRING; + data[0][i] = "v"; + } + FrameBlock fb = new FrameBlock(schema, data); + for(int i = 0; i < nCol; i++) + if(distinct[i] > 0) + fb.setColumnMetadata(i, new ColumnMetadata(distinct[i])); + return fb; + } + private static MatrixBlock run(FrameBlock meta, String spec) { ExecutionContext ec = ExecutionContextFactory.createContext(); ec.setAutoCreateVars(true);