Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,18 @@
import org.apache.sysds.runtime.compress.cost.InstructionTypeCounter;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.functionobjects.IndexFunction;
import org.apache.sysds.runtime.functionobjects.KahanPlus;
import org.apache.sysds.runtime.functionobjects.Mean;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.functionobjects.ReduceCol;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.cp.AggregateBinaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.AggregateUnaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ComputationCPInstruction;
import org.apache.sysds.runtime.instructions.cp.MMChainCPInstruction;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;

public class FederatedWorkloadAnalyzer {
protected static final Log LOG = LogFactory.getLog(FederatedWorkloadAnalyzer.class.getName());
Expand All @@ -55,7 +64,7 @@ public void incrementWorkload(ExecutionContext ec, long tid, Instruction ins) {
}

public void compressRun(ExecutionContext ec, long tid) {
if(counter >= compressRunFrequency ){
if(counter >= compressRunFrequency) {
counter = 0;
get(tid).forEach((K, V) -> CompressedMatrixBlockFactory.compressAsync(ec, Long.toString(K), V));
}
Expand All @@ -68,6 +77,7 @@ private void incrementWorkload(ExecutionContext ec, long tid, ComputationCPInstr
public void incrementWorkload(ExecutionContext ec, ConcurrentHashMap<Long, InstructionTypeCounter> mm,
ComputationCPInstruction cpIns) {
// TODO: Count transitive closure via lineage
// TODO: add more operations
if(cpIns instanceof AggregateBinaryCPInstruction) {
final String n1 = cpIns.input1.getName();
MatrixObject d1 = (MatrixObject) ec.getCacheableData(n1);
Expand All @@ -81,15 +91,45 @@ public void incrementWorkload(ExecutionContext ec, ConcurrentHashMap<Long, Instr
if(validSize(r1, c1)) {
getOrMakeCounter(mm, Long.parseLong(n1)).incRMM(c2);
// safety add overlapping decompress for RMM
getOrMakeCounter(mm, Long.parseLong(n1)).incOverlappingDecompressions();
getOrMakeCounter(mm, Long.parseLong(n1)).incOverlappingDecompressions(c2);
counter++;
}
if(validSize(r2, c2)) {
getOrMakeCounter(mm, Long.parseLong(n2)).incLMM(r1);
counter++;
}

}
else if(cpIns instanceof MMChainCPInstruction) {
final String n1 = cpIns.input1.getName();
getOrMakeCounter(mm, Long.parseLong(n1)).incRMM(1);
getOrMakeCounter(mm, Long.parseLong(n1)).incLMM(1);
counter++;
}
else if(cpIns instanceof AggregateUnaryCPInstruction) {
Operator op = cpIns.getOperator();
final String n1 = cpIns.input1.getName();
long id = Long.parseLong(n1);
if(op instanceof AggregateUnaryOperator) {
AggregateUnaryOperator aop = (AggregateUnaryOperator) op;
IndexFunction idxF = aop.indexFn;
getOrMakeCounter(mm, id).incDictOps();
if(idxF instanceof ReduceCol) {
if((aop.aggOp.increOp.fn instanceof KahanPlus //
|| aop.aggOp.increOp.fn instanceof Plus //
|| aop.aggOp.increOp.fn instanceof Mean)) {
getOrMakeCounter(mm, id).incDictOps();
}
else {
// increment decompression if row reduce.
getOrMakeCounter(mm, id).incDecompressions();
}
}
else {
getOrMakeCounter(mm, id).incDictOps();
}
}
}

}

private static InstructionTypeCounter getOrMakeCounter(ConcurrentHashMap<Long, InstructionTypeCounter> mm, long id) {
Expand Down Expand Up @@ -117,8 +157,8 @@ private static boolean validSize(int nRow, int nCol) {
return nRow > 90 && nRow >= nCol;
}

@Override
public String toString(){
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append(this.getClass().getSimpleName());
sb.append(" Counter: ");
Expand Down
Loading
Loading