diff --git a/src/main/java/org/apache/sysds/hops/BinaryOp.java b/src/main/java/org/apache/sysds/hops/BinaryOp.java index 2b803a053c1..c28be3c5925 100644 --- a/src/main/java/org/apache/sysds/hops/BinaryOp.java +++ b/src/main/java/org/apache/sysds/hops/BinaryOp.java @@ -763,8 +763,8 @@ protected ExecType optFindExecType(boolean transitive) { checkAndSetForcedPlatform(); - DataType dt1 = getInput().get(0).getDataType(); - DataType dt2 = getInput().get(1).getDataType(); + final DataType dt1 = getInput(0).getDataType(); + final DataType dt2 = getInput(1).getDataType(); if( _etypeForced != null ) { setExecType(_etypeForced); @@ -812,18 +812,28 @@ else if ( dt1 == DataType.SCALAR && dt2 == DataType.MATRIX ) { checkAndSetInvalidCPDimsAndSize(); } - //spark-specific decision refinement (execute unary scalar w/ spark input and + // spark-specific decision refinement (execute unary scalar w/ spark input and // single parent also in spark because it's likely cheap and reduces intermediates) - if(transitive && _etype == ExecType.CP && _etypeForced != ExecType.CP && _etypeForced != ExecType.FED && - getDataType().isMatrix() // output should be a matrix - && (dt1.isScalar() || dt2.isScalar()) // one side should be scalar - && supportsMatrixScalarOperations() // scalar operations - && !(getInput().get(dt1.isScalar() ? 1 : 0) instanceof DataOp) // input is not checkpoint - && getInput().get(dt1.isScalar() ? 1 : 0).getParent().size() == 1 // unary scalar is only parent - && !HopRewriteUtils.isSingleBlock(getInput().get(dt1.isScalar() ? 1 : 0)) // single block triggered exec - && getInput().get(dt1.isScalar() ? 1 : 0).optFindExecType() == ExecType.SPARK) { - // pull unary scalar operation into spark - _etype = ExecType.SPARK; + if(transitive // we allow transitive Spark operations. continue sequences of spark operations + && _etype == ExecType.CP // The instruction is currently in CP + && _etypeForced != ExecType.CP // not forced CP + && _etypeForced != ExecType.FED // not federated + && (getDataType().isMatrix() || getDataType().isFrame()) // output should be a matrix or frame + ) { + final boolean v1 = getInput(0).isScalarOrVectorBellowBlockSize(); + final boolean v2 = getInput(1).isScalarOrVectorBellowBlockSize(); + final boolean left = v1 == true; // left side is the vector or scalar + final Hop sparkIn = getInput(left ? 1 : 0); + if((v1 ^ v2) // XOR only one side is allowed to be a vector or a scalar. + && (supportsMatrixScalarOperations() || op == OpOp2.APPLY_SCHEMA) // supported operation + && sparkIn.getParent().size() == 1 // only one parent + && !HopRewriteUtils.isSingleBlock(sparkIn) // single block triggered exec + && sparkIn.optFindExecType() == ExecType.SPARK // input was spark op. + && !(sparkIn instanceof DataOp) // input is not checkpoint + ) { + // pull operation into spark + _etype = ExecType.SPARK; + } } if( OptimizerUtils.ALLOW_BINARY_UPDATE_IN_PLACE && @@ -853,7 +863,7 @@ else if( (op == OpOp2.CBIND && getDataType().isList()) || (op == OpOp2.RBIND && getDataType().isList())) { _etype = ExecType.CP; } - + //mark for recompile (forever) setRequiresRecompileIfNecessary(); diff --git a/src/main/java/org/apache/sysds/hops/Hop.java b/src/main/java/org/apache/sysds/hops/Hop.java index 86749d44c1c..675fbb380a1 100644 --- a/src/main/java/org/apache/sysds/hops/Hop.java +++ b/src/main/java/org/apache/sysds/hops/Hop.java @@ -1045,6 +1045,12 @@ public final String toString() { // ======================================================================================== + protected boolean isScalarOrVectorBellowBlockSize(){ + return getDataType().isScalar() || (dimsKnown() && + (( _dc.getRows() == 1 && _dc.getCols() < ConfigurationManager.getBlocksize()) + || _dc.getCols() == 1 && _dc.getRows() < ConfigurationManager.getBlocksize())); + } + protected boolean isVector() { return (dimsKnown() && (_dc.getRows() == 1 || _dc.getCols() == 1) ); } @@ -1629,6 +1635,11 @@ protected void setMemoryAndComputeEstimates(Lop lop) { lop.setComputeEstimate(ComputeCost.getHOPComputeCost(this)); } + protected boolean hasSparkOutput(){ + return (this.optFindExecType() == ExecType.SPARK + || (this instanceof DataOp && ((DataOp)this).hasOnlyRDD())); + } + /** * Set parse information. * diff --git a/src/main/java/org/apache/sysds/hops/UnaryOp.java b/src/main/java/org/apache/sysds/hops/UnaryOp.java index b3475edfbae..73e24eb17e2 100644 --- a/src/main/java/org/apache/sysds/hops/UnaryOp.java +++ b/src/main/java/org/apache/sysds/hops/UnaryOp.java @@ -366,7 +366,11 @@ protected double computeOutputMemEstimate( long dim1, long dim2, long nnz ) } else { sparsity = OptimizerUtils.getSparsity(dim1, dim2, nnz); } - return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity, getDataType()); + + if(getDataType() == DataType.FRAME) + return OptimizerUtils.estimateSizeExactFrame(dim1, dim2); + else + return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity); } @Override @@ -463,6 +467,13 @@ public boolean isMetadataOperation() { || _op == OpOp1.CAST_AS_LIST; } + private boolean isDisallowedSparkOps(){ + return isCumulativeUnaryOperation() + || isCastUnaryOperation() + || _op==OpOp1.MEDIAN + || _op==OpOp1.IQM; + } + @Override protected ExecType optFindExecType(boolean transitive) { @@ -493,19 +504,22 @@ else if ( getInput().get(0).areDimsBelowThreshold() || getInput().get(0).isVecto checkAndSetInvalidCPDimsAndSize(); } + //spark-specific decision refinement (execute unary w/ spark input and //single parent also in spark because it's likely cheap and reduces intermediates) - if( _etype == ExecType.CP && _etypeForced != ExecType.CP - && getInput().get(0).optFindExecType() == ExecType.SPARK - && getDataType().isMatrix() - && !isCumulativeUnaryOperation() && !isCastUnaryOperation() - && _op!=OpOp1.MEDIAN && _op!=OpOp1.IQM - && !(getInput().get(0) instanceof DataOp) //input is not checkpoint - && getInput().get(0).getParent().size()==1 ) //unary is only parent - { + if(_etype == ExecType.CP // currently CP instruction + && _etype != ExecType.SPARK /// currently not SP. + && _etypeForced != ExecType.CP // not forced as CP instruction + && getInput(0).hasSparkOutput() // input is a spark instruction + && (getDataType().isMatrix() || getDataType().isFrame()) // output is a matrix or frame + && !isDisallowedSparkOps() // is invalid spark instruction + // && !(getInput().get(0) instanceof DataOp) // input is not checkpoint + // && getInput(0).getParent().size() <= 1// unary is only parent + ) { //pull unary operation into spark _etype = ExecType.SPARK; } + //mark for recompile (forever) setRequiresRecompileIfNecessary(); @@ -520,7 +534,7 @@ && getInput().get(0).getParent().size()==1 ) //unary is only parent } else { setRequiresRecompileIfNecessary(); } - + return _etype; } diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java index 99693635a9b..948a78f96af 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCompAgg.java @@ -486,7 +486,7 @@ private static List> generateUnaryAggregateOverlappingFuture final ArrayList tasks = new ArrayList<>(); final int nCol = m1.getNumColumns(); final int nRow = m1.getNumRows(); - final int blklen = Math.max(64, nRow / k); + final int blklen = Math.max(64, (nRow + k) / k); final List groups = m1.getColGroups(); final boolean shouldFilter = CLALibUtils.shouldPreFilter(groups); if(shouldFilter) { diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java index d82d58e323e..cc7953f8c5d 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibMMChain.java @@ -30,6 +30,7 @@ import org.apache.sysds.runtime.compress.colgroup.AColGroup; import org.apache.sysds.runtime.compress.colgroup.ColGroupConst; import org.apache.sysds.runtime.functionobjects.Multiply; +import org.apache.sysds.runtime.instructions.InstructionUtils; import org.apache.sysds.runtime.matrix.data.LibMatrixBincell; import org.apache.sysds.runtime.matrix.data.LibMatrixReorg; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -95,6 +96,11 @@ public static MatrixBlock mmChain(CompressedMatrixBlock x, MatrixBlock v, Matrix if(x.isEmpty()) return returnEmpty(x, out); + if(ctype == ChainType.XtXv && x.getColGroups().size() < 5 && x.getNumColumns()> 30){ + MatrixBlock tmp = CLALibTSMM.leftMultByTransposeSelf(x, k); + return tmp.aggregateBinaryOperations(tmp, v, out, InstructionUtils.getMatMultOperator(k)); + } + // Morph the columns to efficient types for the operation. x = filterColGroups(x); double preFilterTime = t.stop(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java index f14d6833d95..ce06262b9a5 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java @@ -31,6 +31,8 @@ import org.apache.sysds.conf.DMLConfig; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.compress.colgroup.AColGroup; +import org.apache.sysds.runtime.compress.colgroup.ASDC; +import org.apache.sysds.runtime.compress.colgroup.ASDCZero; import org.apache.sysds.runtime.compress.colgroup.ColGroupConst; import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; @@ -71,10 +73,10 @@ public static MatrixBlock rightMultByMatrix(CompressedMatrixBlock m1, MatrixBloc if(m2 instanceof CompressedMatrixBlock) m2 = ((CompressedMatrixBlock) m2).getUncompressed("Uncompressed right side of right MM", k); - if(betterIfDecompressed(m1)) { - // perform uncompressed multiplication. - return decompressingMatrixMult(m1, m2, k); - } + // if(betterIfDecompressed(m1)) { + // // perform uncompressed multiplication. + // return decompressingMatrixMult(m1, m2, k); + // } if(!allowOverlap) { LOG.trace("Overlapping output not allowed in call to Right MM"); @@ -143,7 +145,9 @@ private static MatrixBlock decompressingMatrixMult(CompressedMatrixBlock m1, Mat private static boolean betterIfDecompressed(CompressedMatrixBlock m) { for(AColGroup g : m.getColGroups()) { - if(!(g instanceof ColGroupUncompressed) && g.getNumValues() * 2 >= m.getNumRows()) { + // TODO add subpport for decompressing RMM to ASDC and ASDCZero + if(!(g instanceof ColGroupUncompressed || g instanceof ASDC || g instanceof ASDCZero) && + g.getNumValues() * 2 >= m.getNumRows()) { return true; } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibTSMM.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibTSMM.java index a1d47a9b150..d0396b63810 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibTSMM.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibTSMM.java @@ -31,6 +31,7 @@ import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.compress.colgroup.AColGroup; +import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed; import org.apache.sysds.runtime.matrix.data.LibMatrixMult; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.util.CommonThreadPool; @@ -42,6 +43,10 @@ private CLALibTSMM() { // private constructor } + public static MatrixBlock leftMultByTransposeSelf(CompressedMatrixBlock cmb, int k) { + return leftMultByTransposeSelf(cmb, new MatrixBlock(), k); + } + /** * Self left Matrix multiplication (tsmm) * @@ -51,24 +56,32 @@ private CLALibTSMM() { * @param ret The output matrix to put the result into * @param k The parallelization degree allowed */ - public static void leftMultByTransposeSelf(CompressedMatrixBlock cmb, MatrixBlock ret, int k) { + public static MatrixBlock leftMultByTransposeSelf(CompressedMatrixBlock cmb, MatrixBlock ret, int k) { + final int numColumns = cmb.getNumColumns(); + final int numRows = cmb.getNumRows(); + if(cmb.isEmpty()) + return new MatrixBlock(numColumns, numColumns, true); + // create output matrix block + if(ret == null) + ret = new MatrixBlock(numColumns, numColumns, false); + else + ret.reset(numColumns, numColumns, false); + ret.allocateDenseBlock(); final List groups = cmb.getColGroups(); - final int numColumns = cmb.getNumColumns(); - if(groups.size() >= numColumns) { + if(groups.size() >= numColumns || containsUncompressedColGroup(groups)) { MatrixBlock m = cmb.getUncompressed("TSMM to many columngroups", k); LibMatrixMult.matrixMultTransposeSelf(m, ret, true, k); - return; + return ret; } - final int numRows = cmb.getNumRows(); final boolean shouldFilter = CLALibUtils.shouldPreFilter(groups); final boolean overlapping = cmb.isOverlapping(); if(shouldFilter) { final double[] constV = new double[numColumns]; final List filteredGroups = CLALibUtils.filterGroups(groups, constV); tsmmColGroups(filteredGroups, ret, numRows, overlapping, k); - addCorrectionLayer(filteredGroups, ret, numRows, numColumns, constV); + addCorrectionLayer(filteredGroups, ret, numRows, numColumns, constV, k); } else { @@ -77,17 +90,23 @@ public static void leftMultByTransposeSelf(CompressedMatrixBlock cmb, MatrixBloc ret.setNonZeros(LibMatrixMult.copyUpperToLowerTriangle(ret)); ret.examSparsity(); + return ret; + } + + private static boolean containsUncompressedColGroup(List groups) { + for(AColGroup g : groups) + if(g instanceof ColGroupUncompressed) + return true; + return false; } private static void addCorrectionLayer(List filteredGroups, MatrixBlock result, int nRows, int nCols, - double[] constV) { + double[] constV, int k) { final double[] retV = result.getDenseBlockValues(); final double[] filteredColSum = CLALibUtils.getColSum(filteredGroups, nCols, nRows); addCorrectionLayer(constV, filteredColSum, nRows, retV); } - - private static void tsmmColGroups(List groups, MatrixBlock ret, int nRows, boolean overlapping, int k) { if(k <= 1) tsmmColGroupsSingleThread(groups, ret, nRows); @@ -136,12 +155,12 @@ private static void tsmmColGroupsMultiThread(List groups, MatrixBlock public static void addCorrectionLayer(double[] constV, double[] filteredColSum, int nRow, double[] ret) { final int nColRow = constV.length; - for(int row = 0; row < nColRow; row++){ + for(int row = 0; row < nColRow; row++) { int offOut = nColRow * row; final double v1l = constV[row]; final double v2l = filteredColSum[row] + constV[row] * nRow; - for(int col = row; col < nColRow; col++){ - ret[offOut + col] += v1l * filteredColSum[col] + v2l * constV[col]; + for(int col = row; col < nColRow; col++) { + ret[offOut + col] += v1l * filteredColSum[col] + v2l * constV[col]; } } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java index 119589a3033..e53958ac4b8 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java @@ -352,7 +352,7 @@ else if(opcode.equalsIgnoreCase(Opcodes.TRANSFORMDECODE.toString())) { // compute transformdecode Decoder decoder = DecoderFactory .createDecoder(getParameterMap().get("spec"), colnames, null, meta, data.getNumColumns()); - FrameBlock fbout = decoder.decode(data, new FrameBlock(decoder.getSchema())); + FrameBlock fbout = decoder.decode(data, new FrameBlock(decoder.getSchema()), InfrastructureAnalyzer.getLocalParallelism()); fbout.setColumnNames(Arrays.copyOfRange(colnames, 0, fbout.getNumColumns())); // release locks diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java index 359df747e7b..0f707b74412 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java @@ -44,6 +44,7 @@ import org.apache.sysds.runtime.controlprogram.parfor.util.IDSequence; import org.apache.sysds.runtime.data.TensorBlock; import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.lib.MatrixBlockFromFrame; import org.apache.sysds.runtime.instructions.Instruction; import org.apache.sysds.runtime.instructions.InstructionUtils; import org.apache.sysds.runtime.instructions.ooc.TeeOOCInstruction; @@ -923,7 +924,7 @@ private void processCastAsMatrixVariableInstruction(ExecutionContext ec) { switch( getInput1().getDataType() ) { case FRAME: { FrameBlock fin = ec.getFrameInput(getInput1().getName()); - MatrixBlock out = DataConverter.convertToMatrixBlock(fin); + MatrixBlock out = MatrixBlockFromFrame.convertToMatrixBlock(fin, k); ec.releaseFrameInput(getInput1().getName()); ec.setMatrixOutput(output.getName(), out); break; diff --git a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java index f4bc9f8b216..f1afcfac194 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java +++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java @@ -62,17 +62,20 @@ public FrameBlock decode(MatrixBlock in, FrameBlock out) { @Override public FrameBlock decode(final MatrixBlock in, final FrameBlock out, final int k) { + if(k <= 1) + return decode(in, out); final ExecutorService pool = CommonThreadPool.get(k); out.ensureAllocatedColumns(in.getNumRows()); try { final List> tasks = new ArrayList<>(); int blz = Math.max(in.getNumRows() / k, 1000); - for(Decoder decoder : _decoders){ - for(int i = 0; i < in.getNumRows(); i += blz){ - final int start = i; - final int end = Math.min(in.getNumRows(), i + blz); - tasks.add(pool.submit(() -> decoder.decode(in, out, start, end))); - } + // Parallelize over row blocks (not over decoders): all decoders must + // run in order within a block, e.g. recode-on-output depends on the + // category indexes produced by the preceding dummycode decoder. + for(int i = 0; i < in.getNumRows(); i += blz){ + final int start = i; + final int end = Math.min(in.getNumRows(), i + blz); + tasks.add(pool.submit(() -> decode(in, out, start, end))); } for(Future f : tasks) f.get();