diff --git a/conf/SystemDS-config.xml.template b/conf/SystemDS-config.xml.template index 88a1c5947ed..153dcb6ef2d 100644 --- a/conf/SystemDS-config.xml.template +++ b/conf/SystemDS-config.xml.template @@ -57,6 +57,10 @@ 64 + + false + false diff --git a/src/main/java/org/apache/sysds/conf/DMLConfig.java b/src/main/java/org/apache/sysds/conf/DMLConfig.java index e1b7b0bb530..a6339656fb0 100644 --- a/src/main/java/org/apache/sysds/conf/DMLConfig.java +++ b/src/main/java/org/apache/sysds/conf/DMLConfig.java @@ -78,6 +78,7 @@ public class DMLConfig public static final String PARALLEL_ENCODE_NUM_THREADS = "sysds.parallel.encode.numThreads"; public static final String PARALLEL_TOKENIZE = "sysds.parallel.tokenize"; public static final String PARALLEL_TOKENIZE_NUM_BLOCKS = "sysds.parallel.tokenize.numBlocks"; + public static final String FRAME_TO_MATRIX_WARN_CAST = "sysds.frame.tomatrix.warncast"; public static final String COMPRESSED_LINALG = "sysds.compressed.linalg"; public static final String COMPRESSED_LINALG_INTERMEDIATE = "sysds.compressed.linalg.intermediate"; public static final String COMPRESSED_LOSSY = "sysds.compressed.lossy"; @@ -159,6 +160,7 @@ public class DMLConfig _defaultVals.put(IO_COMPRESSION_CODEC, "none"); _defaultVals.put(PARALLEL_TOKENIZE, "false"); _defaultVals.put(PARALLEL_TOKENIZE_NUM_BLOCKS, "64"); + _defaultVals.put(FRAME_TO_MATRIX_WARN_CAST, "false"); _defaultVals.put(PARALLEL_ENCODE, "true" ); _defaultVals.put(PARALLEL_ENCODE_STAGED, "false" ); _defaultVals.put(PARALLEL_ENCODE_APPLY_BLOCKS, "-1"); @@ -456,7 +458,7 @@ public static DMLConfig readConfigurationFile(String configPath) public String getConfigInfo() { String[] tmpConfig = new String[] { LOCAL_TMP_DIR,SCRATCH_SPACE,OPTIMIZATION_LEVEL, DEFAULT_BLOCK_SIZE, - CP_PARALLEL_OPS, CP_PARALLEL_IO, PARALLEL_ENCODE, NATIVE_BLAS, NATIVE_BLAS_DIR, + CP_PARALLEL_OPS, CP_PARALLEL_IO, PARALLEL_ENCODE, FRAME_TO_MATRIX_WARN_CAST, NATIVE_BLAS, NATIVE_BLAS_DIR, COMPRESSED_LINALG, COMPRESSED_LOSSY, 
COMPRESSED_VALID_COMPRESSIONS, COMPRESSED_OVERLAPPING, COMPRESSED_SAMPLING_RATIO, COMPRESSED_SOFT_REFERENCE_COUNT, COMPRESSED_COCODE, COMPRESSED_TRANSPOSE, COMPRESSED_TRANSFORMENCODE, DAG_LINEARIZATION, diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java index 99cce9f9e97..972a2893fd8 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DoubleArray.java @@ -377,7 +377,7 @@ public static double parseDouble(String value) { return Double.POSITIVE_INFINITY; else if(len == 4 && value.compareToIgnoreCase("-Inf") == 0) return Double.NEGATIVE_INFINITY; - throw new DMLRuntimeException(e); + throw e; } } diff --git a/src/main/java/org/apache/sysds/runtime/frame/data/lib/MatrixBlockFromFrame.java b/src/main/java/org/apache/sysds/runtime/frame/data/lib/MatrixBlockFromFrame.java index 032afe2cd7c..9ff58065d97 100644 --- a/src/main/java/org/apache/sysds/runtime/frame/data/lib/MatrixBlockFromFrame.java +++ b/src/main/java/org/apache/sysds/runtime/frame/data/lib/MatrixBlockFromFrame.java @@ -25,6 +25,8 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.sysds.conf.ConfigurationManager; +import org.apache.sysds.conf.DMLConfig; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.frame.data.FrameBlock; @@ -32,11 +34,17 @@ import org.apache.sysds.runtime.util.CommonThreadPool; import org.apache.sysds.utils.stats.InfrastructureAnalyzer; -public interface MatrixBlockFromFrame { +public class MatrixBlockFromFrame { public static final Log LOG = LogFactory.getLog(MatrixBlockFromFrame.class.getName()); public static final int blocksizeIJ = 32; + public static Boolean WARNED_FOR_FAILED_CAST = false; + + private MatrixBlockFromFrame(){ + // private 
constructor for code coverage. + } + /** * Converts a frame block with arbitrary schema into a matrix block. Since matrix block only supports value type * double, we do a best effort conversion of non-double types which might result in errors for non-numerical data. @@ -68,11 +76,15 @@ public static MatrixBlock convertToMatrixBlock(FrameBlock frame, MatrixBlock ret if(k == -1) k = InfrastructureAnalyzer.getLocalParallelism(); + // Read once on the calling thread: the thread-local config is not visible to pool workers. + final boolean warnCast = ConfigurationManager.getDMLConfig() + .getBooleanValue(DMLConfig.FRAME_TO_MATRIX_WARN_CAST); + long nnz = 0; if(k == 1) - nnz = convert(frame, ret, n, 0, m); + nnz = convert(frame, ret, n, 0, m, warnCast); else - nnz = convertParallel(frame, ret, m, n, k); + nnz = convertParallel(frame, ret, m, n, k, warnCast); ret.setNonZeros(nnz); ret.examSparsity(); @@ -93,14 +105,37 @@ else if(ret.getNumRows() != m || ret.getNumColumns() != n || ret.isInSparseForma return ret; } - private static long convert(FrameBlock frame, MatrixBlock mb, int n, int rl, int ru) { + private static long convert(FrameBlock frame, MatrixBlock mb, int n, int rl, int ru, boolean warnCast) { + // Strict (default): let number format errors propagate and fail the conversion. + if(!warnCast) + return convertStrict(frame, mb, n, rl, ru); + + // Warn-only: on number format errors fall back to writing NaN for the incompatible cells. 
+ try { + return convertStrict(frame, mb, n, rl, ru); + } + catch(NumberFormatException | DMLRuntimeException e) { + synchronized(MatrixBlockFromFrame.class) { /* stable monitor: the boxed flag is reassigned below */ + if(!WARNED_FOR_FAILED_CAST) { + LOG.error( + "Failed to convert to Matrix because of number format errors, falling back to NaN on incompatible cells", + e); + WARNED_FOR_FAILED_CAST = true; + } + } + return convertSafeCast(frame, mb, n, rl, ru); + } + } + + private static long convertStrict(FrameBlock frame, MatrixBlock mb, int n, int rl, int ru) { if(mb.getDenseBlock().isContiguous()) return convertContiguous(frame, mb, n, rl, ru); else return convertGeneric(frame, mb, n, rl, ru); } - private static long convertParallel(FrameBlock frame, MatrixBlock mb, int m, int n, int k) throws Exception { + private static long convertParallel(FrameBlock frame, MatrixBlock mb, int m, int n, int k, boolean warnCast) + throws Exception { ExecutorService pool = CommonThreadPool.get(k); try { List<Future<Long>> tasks = new ArrayList<>(); @@ -109,7 +144,7 @@ private static long convertParallel(FrameBlock frame, MatrixBlock mb, int m, int for(int i = 0; i < m; i += blkz) { final int start = i; final int end = Math.min(i + blkz, m); - tasks.add(pool.submit(() -> convert(frame, mb, n, start, end))); + tasks.add(pool.submit(() -> convert(frame, mb, n, start, end, warnCast))); } long nnz = 0; @@ -169,4 +204,37 @@ private static long convertBlockGeneric(final FrameBlock frame, long lnnz, final } return lnnz; } + + private static long convertSafeCast(final FrameBlock frame, final MatrixBlock mb, final int n, final int rl, + final int ru) { + final DenseBlock c = mb.getDenseBlock(); + long lnnz = 0; + for(int bi = rl; bi < ru; bi += blocksizeIJ) { + for(int bj = 0; bj < n; bj += blocksizeIJ) { + int bimin = Math.min(bi + blocksizeIJ, ru); + int bjmin = Math.min(bj + blocksizeIJ, n); + lnnz = convertBlockSafeCast(frame, lnnz, c, bi, bj, bimin, bjmin); + } + } + return lnnz; + } + + private static long convertBlockSafeCast(final FrameBlock frame, long 
lnnz, final DenseBlock c, final int rl, + final int cl, final int ru, final int cu) { + for(int i = rl; i < ru; i++) { + final double[] cvals = c.values(i); + final int cpos = c.pos(i); + for(int j = cl; j < cu; j++) { + try { + lnnz += (cvals[cpos + j] = frame.getDoubleNaN(i, j)) != 0 ? 1 : 0; + } + catch(NumberFormatException | DMLRuntimeException e) { + lnnz += 1; + cvals[cpos + j] = Double.NaN; + } + } + } + return lnnz; + } + } diff --git a/src/main/java/org/apache/sysds/utils/DoubleParser.java b/src/main/java/org/apache/sysds/utils/DoubleParser.java index 9c77a3e95c8..c0122f8061f 100644 --- a/src/main/java/org/apache/sysds/utils/DoubleParser.java +++ b/src/main/java/org/apache/sysds/utils/DoubleParser.java @@ -184,7 +184,7 @@ public interface DoubleParser { 0x8e679c2f5e44ff8fL}; public static double parseFloatingPointLiteral(String str, int offset, int endIndex) { - if(endIndex > 100) + if(endIndex > 100)// long string return Double.parseDouble(str); // Skip leading whitespace int index = skipWhitespace(str, offset, endIndex); @@ -197,9 +197,10 @@ public static double parseFloatingPointLiteral(String str, int offset, int endIn } // Parse NaN or Infinity (this occurs rarely) - if(ch >= 'I') - return Double.parseDouble(str); - else if(str.charAt(endIndex - 1) >= 'a') + // : is the first character after numbers. + // 0 is the first number. + // we use the last position, since this is not allowed to be other values than a number. 
+ if(str.charAt(endIndex - 1) > '9' || str.charAt(endIndex - 1) < '0') return Double.parseDouble(str); final double val = parseDecFloatLiteral(str, index, offset, endIndex); diff --git a/src/test/java/org/apache/sysds/test/component/frame/MatrixFromFrameSafeCastTest.java b/src/test/java/org/apache/sysds/test/component/frame/MatrixFromFrameSafeCastTest.java new file mode 100644 index 00000000000..43a53879f17 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/frame/MatrixFromFrameSafeCastTest.java @@ -0,0 +1,246 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.test.component.frame; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; + +import java.lang.reflect.Constructor; +import java.lang.reflect.Modifier; +import java.util.List; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.log4j.spi.LoggingEvent; +import org.apache.sysds.conf.ConfigurationManager; +import org.apache.sysds.conf.DMLConfig; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.data.DenseBlock; +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.columns.Array; +import org.apache.sysds.runtime.frame.data.columns.ArrayFactory; +import org.apache.sysds.runtime.frame.data.lib.MatrixBlockFromFrame; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.test.LoggingUtils; +import org.apache.sysds.test.LoggingUtils.TestAppender; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +/** + * Exercises the defensive NaN fallback in {@link MatrixBlockFromFrame} that triggers when a frame contains cells that + * cannot be parsed into doubles. The fallback is gated behind {@link DMLConfig#FRAME_TO_MATRIX_WARN_CAST}. + */ +public class MatrixFromFrameSafeCastTest { + protected static final Log LOG = LogFactory.getLog(MatrixFromFrameSafeCastTest.class.getName()); + + /** Captures the expected fallback LOG.error so it does not pollute test output. 
*/ + private TestAppender appender; + + private void setWarnCast(boolean enabled) { + ConfigurationManager.getDMLConfig().setTextValue(DMLConfig.FRAME_TO_MATRIX_WARN_CAST, String.valueOf(enabled)); + } + + @Before + public void setUp() { + appender = LoggingUtils.overwrite(); + MatrixBlockFromFrame.WARNED_FOR_FAILED_CAST = false; + setWarnCast(true); + } + + @After + public void tearDown() { + LoggingUtils.reinsert(appender); + // restore the strict (default) behavior to avoid leaking into other tests + setWarnCast(false); + MatrixBlockFromFrame.WARNED_FOR_FAILED_CAST = false; + } + + private static final double NA = Double.NaN; + + /** Expected matrix for {@link #mixedFrame()}: parseable cells keep their value, unparseable cells become NaN. */ + private static final double[][] EXPECTED = {{1.0, 4.0}, {NA, 5.0}, {3.0, NA}}; + + /** + * Build a string frame mixing parseable numbers with values that cannot be cast to double. The non-numeric cells + * force the conversion onto the safe-cast path. 
+ */ + private static FrameBlock mixedFrame() { + Array c1 = ArrayFactory.create(new String[] {"1.0", "abc", "3.0"}); + Array c2 = ArrayFactory.create(new String[] {"4.0", "5.0", "xyz"}); + return new FrameBlock(new Array[] {c1, c2}); + } + + @Test + public void safeCastSingleThread() { + FrameBlock fb = mixedFrame(); + MatrixBlock mb = MatrixBlockFromFrame.convertToMatrixBlock(fb, 1); + compareSafeCast(mb); + } + + @Test + public void safeCastParallel() { + FrameBlock fb = mixedFrame(); + MatrixBlock mb = MatrixBlockFromFrame.convertToMatrixBlock(fb, 4); + compareSafeCast(mb); + } + + @Test + public void safeCastProvidedOutput() { + FrameBlock fb = mixedFrame(); + MatrixBlock mb = MatrixBlockFromFrame.convertToMatrixBlock(fb, new MatrixBlock(3, 2, false), 1); + compareSafeCast(mb); + } + + @Test + public void safeCastNonContiguous() { + FrameBlock fb = mixedFrame(); + MatrixBlock mb = new MatrixBlock(fb.getNumRows(), fb.getNumColumns(), false); + mb.allocateBlock(); + DenseBlock spy = spy(mb.getDenseBlock()); + when(spy.isContiguous()).thenReturn(false); + mb.setDenseBlock(spy); + + mb = MatrixBlockFromFrame.convertToMatrixBlock(fb, mb, 1); + compareSafeCast(mb); + } + + @Test + public void safeCastWarnsOnlyOnce() { + FrameBlock fb = mixedFrame(); + + MatrixBlock first = MatrixBlockFromFrame.convertToMatrixBlock(fb, 1); + assertTrue("Conversion should flag that it fell back to NaN casting", + MatrixBlockFromFrame.WARNED_FOR_FAILED_CAST); + compareSafeCast(first); + + // second conversion takes the already-warned branch + MatrixBlock second = MatrixBlockFromFrame.convertToMatrixBlock(fb, 1); + compareSafeCast(second); + + // the fallback warning must be logged exactly once across both conversions + final List<LoggingEvent> log = LoggingUtils.reinsert(appender); + long warnings = log.stream() + .filter(l -> l.getMessage().toString().contains("falling back to NaN on incompatible cells")) + .count(); + assertEquals(1, warnings); + } + + @Test + public void 
strictThrowsWhenWarnCastDisabled() { + // default behavior: incompatible cells fail the whole conversion + setWarnCast(false); + FrameBlock fb = mixedFrame(); + + Exception e = assertThrows(DMLRuntimeException.class, + () -> MatrixBlockFromFrame.convertToMatrixBlock(fb, 1)); + assertTrue(e.getMessage().contains("Failed to convert FrameBlock to MatrixBlock")); + } + + @Test + public void strictThrowsParallelWhenWarnCastDisabled() { + // default behavior must also fail fast on the multi-threaded path + setWarnCast(false); + FrameBlock fb = mixedFrame(); + + Exception e = assertThrows(DMLRuntimeException.class, + () -> MatrixBlockFromFrame.convertToMatrixBlock(fb, 4)); + assertTrue(e.getMessage().contains("Failed to convert FrameBlock to MatrixBlock")); + } + + @Test + public void warnCastValidFrameConvertsWithoutFallback() { + // warn-cast enabled but every cell is parseable: the strict path succeeds and the NaN + // fallback must never trigger (covers the try-succeeds branch of convert). + Array c1 = ArrayFactory.create(new String[] {"1.0", "2.0", "3.0"}); + Array c2 = ArrayFactory.create(new String[] {"4.0", "5.0", "6.0"}); + FrameBlock fb = new FrameBlock(new Array[] {c1, c2}); + + MatrixBlock mb = MatrixBlockFromFrame.convertToMatrixBlock(fb, 1); + + compare(new double[][] {{1.0, 4.0}, {2.0, 5.0}, {3.0, 6.0}}, mb); + assertFalse("No cells failed to parse, so the fallback must not have been used", + MatrixBlockFromFrame.WARNED_FOR_FAILED_CAST); + + final List<LoggingEvent> log = LoggingUtils.reinsert(appender); + long warnings = log.stream() + .filter(l -> l.getMessage().toString().contains("falling back to NaN on incompatible cells")) + .count(); + assertEquals(0, warnings); + } + + @Test + public void safeCastZeroValues() { + // zero-valued parseable cells must not contribute to the non-zero count even on the safe-cast + // path (covers the ': 0' branch of the nnz ternary), while unparseable cells still become NaN. 
+ Array c1 = ArrayFactory.create(new String[] {"0.0", "abc"}); + Array c2 = ArrayFactory.create(new String[] {"2.0", "0.0"}); + FrameBlock fb = new FrameBlock(new Array[] {c1, c2}); + + MatrixBlock mb = MatrixBlockFromFrame.convertToMatrixBlock(fb, 1); + + compare(new double[][] {{0.0, 2.0}, {NA, 0.0}}, mb); + // non-zeros: 2.0 and the NaN cell count, the two explicit zeros do not + assertEquals(2, mb.getNonZeros()); + } + + @Test + public void safeCastAllInvalid() { + // every cell fails to parse: the whole matrix becomes NaN and each NaN counts as a non-zero + Array c1 = ArrayFactory.create(new String[] {"abc", "def"}); + Array c2 = ArrayFactory.create(new String[] {"ghi", "jkl"}); + FrameBlock fb = new FrameBlock(new Array[] {c1, c2}); + + MatrixBlock mb = MatrixBlockFromFrame.convertToMatrixBlock(fb, 1); + + compare(new double[][] {{NA, NA}, {NA, NA}}, mb); + assertEquals(4, mb.getNonZeros()); + } + + @Test + public void privateConstructor() throws Exception { + Constructor c = MatrixBlockFromFrame.class.getDeclaredConstructor(); + assertTrue("Constructor should be private", Modifier.isPrivate(c.getModifiers())); + c.setAccessible(true); + c.newInstance(); + } + + /** + * Verify that every parseable cell matches its expected value and every unparseable cell became NaN. + */ + private static void compareSafeCast(MatrixBlock mb) { + compare(EXPECTED, mb); + } + + /** + * Verify that the matrix matches the expected values cell by cell, treating NaN cells as expected NaN. 
+ */ + private static void compare(double[][] expected, MatrixBlock mb) { + assertEquals(expected.length, mb.getNumRows()); + assertEquals(expected[0].length, mb.getNumColumns()); + for(int i = 0; i < expected.length; i++) + for(int j = 0; j < expected[i].length; j++) + assertEquals("cell (" + i + "," + j + ")", expected[i][j], mb.get(i, j), 0.0); + } +} diff --git a/src/test/java/org/apache/sysds/test/component/frame/array/CustomArrayTests.java b/src/test/java/org/apache/sysds/test/component/frame/array/CustomArrayTests.java index 73d04f32435..df386d4659d 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/array/CustomArrayTests.java +++ b/src/test/java/org/apache/sysds/test/component/frame/array/CustomArrayTests.java @@ -1899,6 +1899,13 @@ public void parseDoubleInvalid3() { assertEquals(Double.POSITIVE_INFINITY, DoubleArray.parseDouble("iff"), 0.0); } + @Test(expected = NumberFormatException.class) + public void parseDoubleThrowsRawNumberFormatException() { + // the parse failure must surface as the raw NumberFormatException, not a wrapped DMLRuntimeException, + // so callers can distinguish a format error from other runtime failures + DoubleArray.parseDouble("not_a_number"); + } + @Test(expected = Exception.class) public void setDDCArrayWithDDCArray() { Array c = FrameCompressTestUtils.generateArray(100, 32, 5, ValueType.INT32); diff --git a/src/test/java/org/apache/sysds/test/component/misc/DoubleParserTest.java b/src/test/java/org/apache/sysds/test/component/misc/DoubleParserTest.java index 08aa5a94e93..546d6ec1387 100644 --- a/src/test/java/org/apache/sysds/test/component/misc/DoubleParserTest.java +++ b/src/test/java/org/apache/sysds/test/component/misc/DoubleParserTest.java @@ -152,6 +152,18 @@ public void parseWithWhitespace() { compareToDoubleParser(" 132.14"); } + @Test + public void parseTrailingDot() { + // last char '.' 
is below '0', forcing the slow Double.parseDouble path + compareToDoubleParser("132."); + } + + @Test + public void parseTrailingWhitespace() { + // last char ' ' is below '0', forcing the slow Double.parseDouble path + compareToDoubleParser("132.14 "); + } + @Test public void parsePowerOf10(){ compareToDoubleParser("132e10");