package org.jruby.ir.representations;
import org.jruby.dirgra.DirectedGraph;
import org.jruby.dirgra.Edge;
import org.jruby.ir.IRManager;
import org.jruby.ir.IRScope;
import org.jruby.ir.Operation;
import org.jruby.ir.instructions.*;
import org.jruby.ir.operands.Label;
import org.jruby.ir.operands.Operand;
import org.jruby.ir.operands.Variable;
import org.jruby.ir.operands.WrappedIRClosure;
import org.jruby.ir.transformations.inlining.CloneInfo;
import org.jruby.util.log.Logger;
import org.jruby.util.log.LoggerFactory;
import java.util.*;
/**
 * Control-flow graph for a single {@link IRScope}.
 *
 * <p>Vertices are {@link BasicBlock}s held in a {@link DirectedGraph}; every edge is tagged with an
 * {@link EdgeType}. Alongside the graph the CFG keeps a label-to-block table, a map from each
 * protected block to the block that handles exceptions raised in it ("rescuer"), the synthetic
 * entry and exit blocks created by {@link #build(Instr[])}, an optional global ensure block, and a
 * lazily computed post-order of the blocks.
 *
 * <p>NOTE(review): nothing here is synchronized; presumably an instance is only used by one thread
 * at a time — confirm before sharing instances.
 */
public class CFG {
    /** Kinds of control-flow edges. */
    public enum EdgeType {
        /** To the target of an explicit jump or branch. */
        REGULAR,
        /** From a block that may raise to the block that handles the exception. */
        EXCEPTION,
        /** To the block that immediately follows in instruction order. */
        FALL_THROUGH,
        /** Into the synthetic exit block. */
        EXIT
    }

    private static final Logger LOG = LoggerFactory.getLogger(CFG.class);

    /** Scope whose instructions this graph describes. */
    private final IRScope scope;
    /** Label -> block lookup; kept in sync by addBasicBlock and removeBB. */
    private final Map<Label, BasicBlock> bbMap;
    /** Protected block -> block that receives exceptions raised in it. */
    private final Map<BasicBlock, BasicBlock> rescuerMap;
    private BasicBlock entryBB;
    private BasicBlock exitBB;
    private BasicBlock globalEnsureBB;
    private final DirectedGraph<BasicBlock> graph;
    /** Last id handed out by getNextBBID(); 0 before the first call. */
    private int nextBBId;
    /** Cached post-order of the blocks; null means "recompute on next request". */
    LinkedList<BasicBlock> postOrderList;

    /**
     * Creates an empty CFG for {@code scope}.
     *
     * @param scope the owning scope (source of new labels, the IR manager and closures)
     */
    public CFG(IRScope scope) {
        this.scope = scope;
        this.graph = new DirectedGraph<>();
        this.bbMap = new HashMap<>();
        this.rescuerMap = new HashMap<>();
        this.nextBBId = 0;
        this.entryBB = null;
        this.exitBB = null;
        this.globalEnsureBB = null;
        this.postOrderList = null;
    }

    /**
     * Hands out the next basic-block id. Ids are pre-incremented, so the first id returned is 1
     * and {@link #getMaxNodeID()} always equals the most recently returned id.
     */
    public int getNextBBID() {
        return ++nextBBId;
    }

    /** @return the IR manager of the owning scope */
    public IRManager getManager() {
        return scope.getManager();
    }

    /** @return the largest id handed out by {@link #getNextBBID()} so far (0 if none) */
    public int getMaxNodeID() {
        return nextBBId;
    }

    /** @return true when {@code b} has a rescuer block registered */
    public boolean bbIsProtected(BasicBlock b) {
        return getRescuerBBFor(b) != null;
    }

    /** @return the block registered under {@code label}, or null if none */
    public BasicBlock getBBForLabel(Label label) {
        return bbMap.get(label);
    }

    /** @return the synthetic entry block (null before {@link #build(Instr[])} or clone) */
    public BasicBlock getEntryBB() {
        return entryBB;
    }

    /** @return the synthetic exit block (null before {@link #build(Instr[])} or clone) */
    public BasicBlock getExitBB() {
        return exitBB;
    }

    /** @return the global ensure block, or null if there is none */
    public BasicBlock getGlobalEnsureBB() {
        return globalEnsureBB;
    }

    /**
     * Returns the post-order of the blocks, computing and caching it on first use. The cache is
     * dropped by {@link #resetState()}, {@link #addBasicBlock(BasicBlock)} and
     * {@link #removeBB(BasicBlock)}; edge changes do NOT invalidate it.
     */
    public LinkedList<BasicBlock> postOrderList() {
        if (postOrderList == null) postOrderList = buildPostOrderList();
        return postOrderList;
    }

    /** @return an iterator over {@link #postOrderList()} from first to last */
    public Iterator<BasicBlock> getPostOrderTraverser() {
        return postOrderList().iterator();
    }

    /** @return an iterator over {@link #postOrderList()} from last to first */
    public Iterator<BasicBlock> getReversePostOrderTraverser() {
        return postOrderList().descendingIterator();
    }

    /** Drops the cached post-order list so the next request recomputes it. */
    public void resetState() {
        postOrderList = null;
    }

    /** @return the scope this CFG belongs to */
    public IRScope getScope() {
        return scope;
    }

    /** @return the number of vertices in the graph */
    public int size() {
        return graph.size();
    }

    /** @return all blocks currently in the graph, as reported by {@code DirectedGraph.allData()} */
    public Collection<BasicBlock> getBasicBlocks() {
        return graph.allData();
    }

    /** @return the blocks in the order produced by {@code DirectedGraph.getInorderData()} */
    public Collection<BasicBlock> getSortedBasicBlocks() {
        return graph.getInorderData();
    }

    /**
     * Adds an edge tagged {@code type} from {@code source} to {@code destination}, creating the
     * source vertex if it does not exist yet.
     */
    public void addEdge(BasicBlock source, BasicBlock destination, Object type) {
        graph.findOrCreateVertexFor(source).addEdgeTo(destination, type);
    }

    /** @return the in-degree of {@code b}; {@code b} must be in the graph */
    public int inDegree(BasicBlock b) {
        return graph.findVertexFor(b).inDegree();
    }

    /** @return the out-degree of {@code b}; {@code b} must be in the graph */
    public int outDegree(BasicBlock b) {
        return graph.findVertexFor(b).outDegree();
    }

    /** @return the predecessor blocks of {@code block} */
    public Iterable<BasicBlock> getIncomingSources(BasicBlock block) {
        return graph.findVertexFor(block).getIncomingSourcesData();
    }

    /** @return the incoming edges of {@code block} */
    public Iterable<Edge<BasicBlock>> getIncomingEdges(BasicBlock block) {
        return graph.findVertexFor(block).getIncomingEdges();
    }

    /** @return a predecessor of {@code block} reached through an edge of {@code type} */
    public BasicBlock getIncomingSourceOfType(BasicBlock block, Object type) {
        return graph.findVertexFor(block).getIncomingSourceDataOfType(type);
    }

    /** @return a successor of {@code block} reached through an edge of {@code type} */
    public BasicBlock getOutgoingDestinationOfType(BasicBlock block, Object type) {
        return graph.findVertexFor(block).getOutgoingDestinationDataOfType(type);
    }

    /** @return the successor blocks of {@code block} */
    public Iterable<BasicBlock> getOutgoingDestinations(BasicBlock block) {
        return graph.findVertexFor(block).getOutgoingDestinationsData();
    }

    /** @return the successors of {@code block} reached through edges of {@code type} */
    public Iterable<BasicBlock> getOutgoingDestinationsOfType(BasicBlock block, Object type) {
        return graph.findVertexFor(block).getOutgoingDestinationsDataOfType(type);
    }

    /** @return the successors of {@code block} reached through edges not of {@code type} */
    public Iterable<BasicBlock> getOutgoingDestinationsNotOfType(BasicBlock block, Object type) {
        return graph.findVertexFor(block).getOutgoingDestinationsDataNotOfType(type);
    }

    /** @return the outgoing edges of {@code block} */
    public Collection<Edge<BasicBlock>> getOutgoingEdges(BasicBlock block) {
        return graph.findVertexFor(block).getOutgoingEdges();
    }

    /** @return the rescuer registered for {@code block}, or null when it is unprotected */
    public BasicBlock getRescuerBBFor(BasicBlock block) {
        return rescuerMap.get(block);
    }

    /**
     * Installs {@code geb} as the global ensure block. It is added to the graph with an EXIT edge
     * to the exit block, and becomes the rescuer (with an EXCEPTION edge) of every other block
     * that is not already protected, except the entry block. May only be called once (asserted).
     */
    public void addGlobalEnsureBB(BasicBlock geb) {
        assert globalEnsureBB == null: "CFG for scope " + getScope() + " already has a global ensure block.";

        addBasicBlock(geb);
        addEdge(geb, getExitBB(), EdgeType.EXIT);

        for (BasicBlock b : getBasicBlocks()) {
            if (b != geb && !bbIsProtected(b) && b != getEntryBB()) {
                addEdge(b, geb, EdgeType.EXCEPTION);
                setRescuerBB(b, geb);
            }
        }

        globalEnsureBB = geb;
    }

    /** Registers {@code rescuerBlock} as the handler for exceptions raised in {@code block}. */
    public void setRescuerBB(BasicBlock block, BasicBlock rescuerBlock) {
        rescuerMap.put(block, rescuerBlock);
    }

    /**
     * Builds the graph from the scope's linear instruction stream, then runs
     * {@link #optimize(List)}.
     *
     * <p>A new block starts at every label and at the first instruction following a block-ending
     * instruction (an exception-region end marker in that position stays with the block that just
     * ended). Jump/branch targets become REGULAR edges — via forward references when the target
     * label has not been seen yet — and sequential flow becomes FALL_THROUGH edges. Every block of a
     * rescued exception region that can raise gets an EXCEPTION edge to, and is registered with, the
     * region's first rescue block. Synthetic entry and exit blocks are created around the code.
     *
     * @param instrs the instructions, in program order
     * @return the underlying graph
     * @throws RuntimeException if a block-ending instruction is of an unhandled kind
     */
    public DirectedGraph<BasicBlock> build(Instr[] instrs) {
        // Blocks that jump to a label whose block has not been created yet, keyed by that label.
        Map<Label, List<BasicBlock>> forwardRefs = new HashMap<>();
        List<BasicBlock> returnBBs = new ArrayList<>();
        List<BasicBlock> exceptionBBs = new ArrayList<>();
        Stack<ExceptionRegion> nestedExceptionRegions = new Stack<>();
        List<ExceptionRegion> allExceptionRegions = new ArrayList<>();

        entryBB = createBB(nestedExceptionRegions);
        BasicBlock firstBB = createBB(nestedExceptionRegions);
        BasicBlock currBB = firstBB;
        BasicBlock newBB;
        boolean bbEnded = false;
        boolean nextBBIsFallThrough = true;

        for (Instr i : instrs) {
            Operation iop = i.getOperation();
            if (iop == Operation.LABEL) {
                // Every label opens a new block.
                Label l = ((LabelInstr) i).getLabel();
                newBB = createBB(l, nestedExceptionRegions);
                if (nextBBIsFallThrough) graph.addEdge(currBB, newBB, EdgeType.FALL_THROUGH);
                currBB = newBB;
                bbEnded = false;
                nextBBIsFallThrough = true;

                // Resolve jumps seen earlier that targeted this label.
                List<BasicBlock> frefs = forwardRefs.get(l);
                if (frefs != null) {
                    for (BasicBlock b : frefs) {
                        graph.addEdge(b, newBB, EdgeType.REGULAR);
                    }
                }
            } else if (bbEnded && iop != Operation.EXC_REGION_END) {
                // The previous block ended without a following label: open an unlabeled block.
                newBB = createBB(nestedExceptionRegions);
                if (nextBBIsFallThrough) graph.addEdge(currBB, newBB, EdgeType.FALL_THROUGH);
                currBB = newBB;
                bbEnded = false;
                nextBBIsFallThrough = true;
            }

            if (i instanceof ExceptionRegionStartMarkerInstr) {
                ExceptionRegionStartMarkerInstr ersmi = (ExceptionRegionStartMarkerInstr) i;
                ExceptionRegion rr = new ExceptionRegion(ersmi.getFirstRescueBlockLabel(), currBB);
                rr.addBB(currBB);
                allExceptionRegions.add(rr);
                if (!nestedExceptionRegions.empty()) {
                    nestedExceptionRegions.peek().addNestedRegion(rr);
                }
                nestedExceptionRegions.push(rr);
            } else if (i instanceof ExceptionRegionEndMarkerInstr) {
                nestedExceptionRegions.pop().setEndBB(currBB);
            } else if (iop.endsBasicBlock()) {
                bbEnded = true;
                currBB.addInstr(i);
                Label tgt = null;
                // Only a conditional branch can also flow into the next block.
                nextBBIsFallThrough = false;
                if (i instanceof BranchInstr) {
                    tgt = ((BranchInstr) i).getJumpTarget();
                    nextBBIsFallThrough = true;
                } else if (i instanceof MultiBranchInstr) {
                    Label[] tgts = ((MultiBranchInstr) i).getJumpTargets();
                    for (Label l : tgts) addEdge(currBB, l, forwardRefs);
                } else if (i instanceof JumpInstr) {
                    tgt = ((JumpInstr) i).getJumpTarget();
                } else if (iop.isReturn()) {
                    tgt = null;
                    returnBBs.add(currBB);
                } else if (i instanceof ThrowExceptionInstr) {
                    tgt = null;
                    exceptionBBs.add(currBB);
                } else {
                    throw new RuntimeException("Unhandled case in CFG builder for basic block ending instr: " + i);
                }

                if (tgt != null) addEdge(currBB, tgt, forwardRefs);
            } else if (iop != Operation.LABEL) {
                // Labels are represented by the block itself, not stored as instructions.
                currBB.addInstr(i);
            }
        }

        // Connect raising blocks of every rescued region to the region's first rescue block.
        for (ExceptionRegion rr : allExceptionRegions) {
            Label rescueLabel = rr.getFirstRescueBlockLabel();
            if (Label.UNRESCUED_REGION_LABEL.equals(rescueLabel)) continue;

            BasicBlock firstRescueBB = bbMap.get(rescueLabel);
            firstRescueBB.markRescueEntryBB();
            for (BasicBlock b : rr.getExclusiveBBs()) {
                if (b.canRaiseExceptions()) {
                    setRescuerBB(b, firstRescueBB);
                    graph.addEdge(b, firstRescueBB, EdgeType.EXCEPTION);
                }
            }
        }

        buildExitBasicBlock(nestedExceptionRegions, firstBB, returnBBs, exceptionBBs, nextBBIsFallThrough, currBB, entryBB);

        optimize(returnBBs);

        return graph;
    }

    /**
     * Adds a REGULAR edge from {@code src} to the block labelled {@code targetLabel}. If that
     * block does not exist yet, {@code src} is recorded in {@code forwardRefs} so the edge can be
     * added when the label is reached.
     */
    private void addEdge(BasicBlock src, Label targetLabel, Map<Label, List<BasicBlock>> forwardRefs) {
        BasicBlock target = bbMap.get(targetLabel);
        if (target != null) {
            graph.addEdge(src, target, EdgeType.REGULAR);
            return;
        }

        List<BasicBlock> pending = forwardRefs.get(targetLabel);
        if (pending == null) {
            pending = new ArrayList<>();
            forwardRefs.put(targetLabel, pending);
        }
        pending.add(src);
    }

    /**
     * Creates the exit block and wires it up. The entry block gets an EXIT edge to it and a
     * FALL_THROUGH edge to {@code firstBB}. Every return block and every throwing block gets an
     * EXIT edge to it, as does {@code currBB} when control can fall off the end of the code.
     *
     * @return the new exit block
     */
    private BasicBlock buildExitBasicBlock(Stack<ExceptionRegion> nestedExceptionRegions, BasicBlock firstBB,
                                           List<BasicBlock> returnBBs, List<BasicBlock> exceptionBBs,
                                           boolean nextIsFallThrough, BasicBlock currBB, BasicBlock entryBB) {
        exitBB = createBB(nestedExceptionRegions);

        graph.addEdge(entryBB, exitBB, EdgeType.EXIT);
        graph.addEdge(entryBB, firstBB, EdgeType.FALL_THROUGH);

        for (BasicBlock rb : returnBBs) {
            graph.addEdge(rb, exitBB, EdgeType.EXIT);
        }

        for (BasicBlock rb : exceptionBBs) {
            graph.addEdge(rb, exitBB, EdgeType.EXIT);
        }

        if (nextIsFallThrough) graph.addEdge(currBB, exitBB, EdgeType.EXIT);

        return exitBB;
    }

    /**
     * Creates and registers a block labelled {@code label}. It becomes the global ensure block when
     * the label is the global-ensure label, and it joins the innermost open exception region.
     */
    private BasicBlock createBB(Label label, Stack<ExceptionRegion> nestedExceptionRegions) {
        BasicBlock basicBlock = new BasicBlock(this, label);
        addBasicBlock(basicBlock);

        if (label.isGlobalEnsureBlockLabel()) {
            globalEnsureBB = basicBlock;
        }

        if (!nestedExceptionRegions.empty()) nestedExceptionRegions.peek().addBB(basicBlock);

        return basicBlock;
    }

    /** Same as {@link #createBB(Label, Stack)} with a fresh label from the scope. */
    private BasicBlock createBB(Stack<ExceptionRegion> nestedExceptionRegions) {
        return createBB(scope.getNewLabel(), nestedExceptionRegions);
    }

    /** Adds {@code bb} as a vertex, indexes it by label and invalidates the cached post-order. */
    public void addBasicBlock(BasicBlock bb) {
        graph.findOrCreateVertexFor(bb);
        bbMap.put(bb.getLabel(), bb);
        postOrderList = null;
    }

    /** Removes every outgoing edge of {@code b}. */
    public void removeAllOutgoingEdgesForBB(BasicBlock b) {
        graph.findVertexFor(b).removeAllOutgoingEdges();
    }

    /**
     * Removes every block not reachable from the entry block (breadth-first over all edge types),
     * together with the closures referenced by their instructions.
     */
    private void deleteOrphanedBlocks() {
        Deque<BasicBlock> worklist = new ArrayDeque<>();
        Set<BasicBlock> living = new HashSet<>();

        worklist.add(entryBB);
        living.add(entryBB);

        while (!worklist.isEmpty()) {
            BasicBlock current = worklist.remove();
            for (BasicBlock bb : graph.findVertexFor(current).getOutgoingDestinationsData()) {
                if (living.add(bb)) worklist.add(bb);
            }
        }

        // Collect first: removeBB mutates the graph we would otherwise be iterating.
        List<BasicBlock> dead = new ArrayList<>();
        for (BasicBlock bb : graph.allData()) {
            if (!living.contains(bb)) dead.add(bb);
        }

        for (BasicBlock bb : dead) {
            removeBB(bb);
            removeNestedScopesFromBB(bb);
        }
    }

    /**
     * Merges {@code b} into {@code a}, where {@code a} flows into {@code b}. Allowed when both have
     * the same rescuer or when either block is empty (presumably because an empty block cannot
     * raise). A trailing jump in {@code a} is dropped, {@code b}'s instructions are appended, and
     * {@code b}'s outgoing and incoming edges are moved to {@code a}. Predecessors ending in a
     * jump-target instruction are retargeted to {@code a}'s label.
     *
     * @return true if the blocks were merged
     */
    private boolean mergeBBs(BasicBlock a, BasicBlock b) {
        BasicBlock aR = getRescuerBBFor(a);
        BasicBlock bR = getRescuerBBFor(b);
        if (aR != bR && !a.isEmpty() && !b.isEmpty()) return false;

        Instr lastInstr = a.getLastInstr();
        if (lastInstr instanceof JumpInstr) a.removeInstr(lastInstr);

        a.swallowBB(b);

        removeEdge(a, b);
        for (Edge<BasicBlock> e : getOutgoingEdges(b)) {
            addEdge(a, e.getDestination().getData(), e.getType());
        }

        // Snapshot: the loop below removes edges from b's incoming set, so iterating it live
        // risks a ConcurrentModificationException or skipped edges.
        List<Edge<BasicBlock>> incoming = new ArrayList<>();
        for (Edge<BasicBlock> e : getIncomingEdges(b)) incoming.add(e);

        for (Edge<BasicBlock> e : incoming) {
            BasicBlock fixupBB = e.getSource().getData();
            removeEdge(fixupBB, b);
            addEdge(fixupBB, a, e.getType());

            Instr fixupLastInstr = fixupBB.getLastInstr();
            if (fixupLastInstr instanceof JumpTargetInstr) {
                ((JumpTargetInstr) fixupLastInstr).setJumpTarget(a.getLabel());
            }
        }

        removeBB(b);

        // a inherits b's handler when it had none of its own.
        if (aR == null && bR != null) {
            setRescuerBB(a, bR);
        }

        return true;
    }

    /**
     * Removes {@code b} from the graph, the label table and the rescuer map, clears the global
     * ensure block if it was {@code b}, and invalidates the cached post-order.
     */
    public void removeBB(BasicBlock b) {
        if (b == globalEnsureBB) globalEnsureBB = null;

        graph.removeVertexFor(b);
        bbMap.remove(b.getLabel());
        rescuerMap.remove(b);
        postOrderList = null;
    }

    /**
     * Removes from the scope every closure referenced by a {@link WrappedIRClosure} operand in
     * {@code bb}; only the first such operand of each instruction is considered.
     */
    private void removeNestedScopesFromBB(BasicBlock bb) {
        for (Instr instr : bb.getInstrs()) {
            for (Operand oper : instr.getOperands()) {
                if (oper instanceof WrappedIRClosure) {
                    scope.removeClosure(((WrappedIRClosure) oper).getClosure());
                    break;
                }
            }
        }
    }

    /**
     * Merges each block that has a single non-exception successor into that successor, provided
     * the successor has no other predecessor or the block itself is empty.
     */
    public void collapseStraightLineBBs() {
        // Snapshot: merging removes blocks from the graph.
        List<BasicBlock> cfgBBs = new ArrayList<>(getBasicBlocks());
        Set<BasicBlock> mergedBBs = new HashSet<>();

        for (BasicBlock b : cfgBBs) {
            if (mergedBBs.contains(b) || outDegree(b) != 1) continue;

            // Fetch the sole edge up front; mergeBBs rewrites b's outgoing edges, so do not
            // keep an iterator over them open across the merge.
            Edge<BasicBlock> e = getOutgoingEdges(b).iterator().next();
            BasicBlock outB = e.getDestination().getData();
            if (e.getType() != EdgeType.EXCEPTION && outB != b
                    && (inDegree(outB) == 1 || b.isEmpty()) && mergeBBs(b, outB)) {
                mergedBBs.add(outB);
            }
        }
    }

    /**
     * Post-build clean-up. Consider a return block that consists only of {@code return v}, with
     * {@code v} a variable, and a non-exception predecessor whose last instruction (ignoring one
     * trailing jump) is {@code v = copy(src)}. The predecessor's final instruction is replaced by
     * {@code return src}, and its edge to the return block becomes an EXIT edge to the exit block.
     * When a jump followed the copy, the jump is the instruction replaced and the copy is kept.
     * Unreachable blocks are then deleted and straight-line blocks collapsed.
     *
     * @param returnBBs blocks that ended with a return while building
     */
    public void optimize(List<BasicBlock> returnBBs) {
        List<Edge<BasicBlock>> toRemove = new ArrayList<>();

        for (BasicBlock retBB : returnBBs) {
            List<Instr> rbInstrs = retBB.getInstrs();
            if (rbInstrs.isEmpty() || !(rbInstrs.get(0) instanceof ReturnInstr)) continue;

            Operand rv = ((ReturnInstr) rbInstrs.get(0)).getReturnValue();
            if (!(rv instanceof Variable)) continue;

            for (Edge<BasicBlock> e : getIncomingEdges(retBB)) {
                // Along an exception edge the source's normal flow goes elsewhere, so ending
                // the source with a return would be wrong.
                if (e.getType() == EdgeType.EXCEPTION) continue;

                BasicBlock srcBB = e.getSource().getData();
                List<Instr> srcInstrs = srcBB.getInstrs();
                int n = srcInstrs.size();
                if (n == 0) continue;

                // Look through a single trailing jump at the instruction before it.
                Instr last = srcInstrs.get(n - 1);
                if (last instanceof JumpInstr && n >= 2) last = srcInstrs.get(n - 2);

                if (last instanceof CopyInstr && ((CopyInstr) last).getResult().equals(rv)) {
                    srcInstrs.set(n - 1, new ReturnInstr(((CopyInstr) last).getSource()));
                    toRemove.add(e);
                    addEdge(srcBB, exitBB, EdgeType.EXIT);
                }
            }
        }

        for (Edge<BasicBlock> edge : toRemove) {
            graph.removeEdge(edge);
        }

        deleteOrphanedBlocks();

        collapseStraightLineBBs();
    }

    /** @return the graph's own string form */
    public String toStringGraph() {
        return graph.toString();
    }

    /** @return every block's instructions (in getSortedBasicBlocks order) followed by the rescuer map sorted by block */
    public String toStringInstrs() {
        StringBuilder buf = new StringBuilder();

        for (BasicBlock b : getSortedBasicBlocks()) {
            buf.append(b.toStringInstrs());
        }

        buf.append("\n\n------ Rescue block map ------\n");
        List<BasicBlock> protectedBBs = new ArrayList<>(rescuerMap.keySet());
        Collections.sort(protectedBBs);
        for (BasicBlock bb : protectedBBs) {
            buf.append("BB ").append(bb.getID()).append(" --> BB ").append(rescuerMap.get(bb).getID()).append("\n");
        }

        return buf.toString();
    }

    /** Removes the edge(s) from {@code a} to {@code b}, as implemented by {@code DirectedGraph.removeEdge}. */
    public void removeEdge(BasicBlock a, BasicBlock b) {
        graph.removeEdge(a, b);
    }

    /**
     * Iterative depth-first walk from the entry block. A block is appended after all of its
     * successors have been visited; successors without outgoing edges are appended as soon as
     * they are discovered. Logs an error (for the first one found) if some block was never reached.
     *
     * <p>NOTE(review): all unvisited successors are marked visited when first seen, so the result can
     * differ from a textbook DFS post-order when siblings have edges between them.
     */
    private LinkedList<BasicBlock> buildPostOrderList() {
        BasicBlock root = getEntryBB();
        LinkedList<BasicBlock> list = new LinkedList<>();
        Stack<BasicBlock> stack = new Stack<>();
        boolean[] visited = new boolean[1 + getMaxNodeID()];

        stack.push(root);
        visited[root.getID()] = true;

        while (!stack.empty()) {
            BasicBlock b = stack.peek();
            boolean allChildrenVisited = true;

            for (BasicBlock dst : getOutgoingDestinations(b)) {
                int dstID = dst.getID();
                if (!visited[dstID]) {
                    allChildrenVisited = false;
                    if (graph.findVertexFor(dst).outDegree() == 0) {
                        list.add(dst);
                    } else {
                        stack.push(dst);
                    }
                    visited[dstID] = true;
                }
            }

            if (allChildrenVisited) {
                stack.pop();
                list.add(b);
            }
        }

        for (BasicBlock b : getBasicBlocks()) {
            if (!visited[b.getID()]) {
                printError("BB " + b.getID() + " missing from po list!");
                break;
            }
        }

        return list;
    }

    /**
     * Deep-copies this CFG for {@code clonedScope}. Blocks are cloned with {@code info}, and every
     * edge, rescuer mapping, and the entry, exit and global ensure blocks are mapped onto the clones.
     */
    public CFG clone(CloneInfo info, IRScope clonedScope) {
        CFG newCFG = new CFG(clonedScope);
        Map<BasicBlock, BasicBlock> cloneBBMap = new HashMap<>();

        for (BasicBlock bb : getBasicBlocks()) {
            BasicBlock newBB = bb.clone(info, newCFG);
            newCFG.addBasicBlock(newBB);
            cloneBBMap.put(bb, newBB);
        }

        for (BasicBlock bb : getBasicBlocks()) {
            BasicBlock newSource = cloneBBMap.get(bb);
            for (Edge<BasicBlock> edge : getOutgoingEdges(bb)) {
                BasicBlock newDestination = cloneBBMap.get(edge.getDestination().getData());
                newCFG.addEdge(newSource, newDestination, edge.getType());
            }
        }

        for (BasicBlock bb : rescuerMap.keySet()) {
            newCFG.setRescuerBB(cloneBBMap.get(bb), cloneBBMap.get(rescuerMap.get(bb)));
        }

        newCFG.entryBB = cloneBBMap.get(entryBB);
        newCFG.exitBB = cloneBBMap.get(exitBB);
        newCFG.globalEnsureBB = cloneBBMap.get(globalEnsureBB);

        return newCFG;
    }

    /** Logs {@code message} together with the graph and the instruction listing. */
    private void printError(String message) {
        // CFG has no toString(), so concatenating `this` would only print an identity hash.
        LOG.error(message + "\nGraph:\n" + toStringGraph() + "\nInstructions:\n" + toStringInstrs());
    }
}