package org.jruby.ir.interpreter;

import org.jruby.RubyModule;
import org.jruby.internal.runtime.methods.MixedModeIRMethod;
import org.jruby.ir.*;
import org.jruby.ir.instructions.CallBase;
import org.jruby.ir.instructions.Instr;
import org.jruby.ir.operands.Operand;
import org.jruby.ir.operands.WrappedIRClosure;
import org.jruby.runtime.CallSite;
import org.jruby.runtime.callsite.CacheEntry;
import org.jruby.runtime.callsite.CachingCallSite;

import java.util.*;

public class Profiler {
    private static class IRCallSite {
        IRScope  s;
        int      v; // scope version
        CallBase call;
        long     count;
        MixedModeIRMethod tgtM;

        public IRCallSite() {}

        public IRCallSite(IRCallSite cs) {
            this.s     = cs.s;
            this.v     = cs.v;
            this.call  = cs.call;
            this.count = 0;
        }

        public int hashCode() {
            return (int)this.call.callSiteId;
        }
    }

    private static class CallSiteProfile {
        IRCallSite cs;
        HashMap<IRScope, Counter> counters;

        public CallSiteProfile(IRCallSite cs) {
            this.cs = new IRCallSite(cs);
            this.counters = new HashMap<IRScope,Counter>();
        }
    }

    private static IRCallSite callerSite = new IRCallSite();

    private static int inlineCount = 0;
    private static int globalThreadPollCount = 0;
    private static int codeModificationsCount = 0;
    private static int numCyclesWithNoModifications = 0;
    private static int versionCount = 1;

    private static HashMap<IRScope, Integer> scopeVersionMap = new HashMap<IRScope, Integer>();
    private static HashMap<IRScope, Counter> scopeThreadPollCounts = new HashMap<IRScope, Counter>();
    private static HashMap<Long, CallSiteProfile> callProfile = new HashMap<Long, CallSiteProfile>();
    private static HashMap<Operation, Counter> opStats = new HashMap<Operation, Counter>();

    private static void analyzeProfile() {
        versionCount++;

        //if (inlineCount == 2) return;

        if (codeModificationsCount == 0) numCyclesWithNoModifications++;
        else numCyclesWithNoModifications = 0;

        codeModificationsCount = 0;

        if (numCyclesWithNoModifications < 3) return;

        // We are now good to go -- start analyzing the profile

        // System.out.println("-------------------start analysis-----------------------");

        final HashMap<IRScope, Long> scopeCounts = new HashMap<IRScope, Long>();
        final ArrayList<IRCallSite> callSites = new ArrayList<IRCallSite>();
        //HashMap<IRCallSite, Long> callSiteCounts = new HashMap<IRCallSite, Long>();
        // System.out.println("# call sites: " + callProfile.keySet().size());
        long total = 0;
        for (Long id: callProfile.keySet()) {
            Long c;

            CallSiteProfile csp = callProfile.get(id);
            IRCallSite      cs  = csp.cs;

            if (cs.v != scopeVersionMap.get(cs.s).intValue()) {
                System.out.println("Skipping callsite: <" + cs.s + "," + cs.v + "> with compiled version: " + scopeVersionMap.get(cs.s));
                continue;
            }

            Set<IRScope> calledScopes = csp.counters.keySet();
            cs.count = 0;
            for (IRScope s: calledScopes) {
                c = scopeCounts.get(s);
                if (c == null) {
                    c = Long.valueOf(0);
                    scopeCounts.put(s, c);
                }

                long x = csp.counters.get(s).count;
                c += x;
                cs.count += x;
            }

            CallBase call = cs.call;
            if (calledScopes.size() == 1 && !call.inliningBlocked()) {
                CallSite runtimeCS = call.getCallSite();
                if (runtimeCS != null && (runtimeCS instanceof CachingCallSite)) {
                    CachingCallSite ccs = (CachingCallSite)runtimeCS;
                    CacheEntry ce = ccs.getCache();

                    if (!(ce.method instanceof MixedModeIRMethod)) {
                        // System.out.println("NOT IR-M!");
                        continue;
                    } else {
                        callSites.add(cs);
                        cs.tgtM = (MixedModeIRMethod)ce.method;
                    }
                }
            }

            total += cs.count;
        }

        Collections.sort(callSites, new java.util.Comparator<IRCallSite> () {
            @Override
            public int compare(IRCallSite a, IRCallSite b) {
                if (a.count == b.count) return 0;
                return (a.count < b.count) ? 1 : -1;
            }
        });

        // Find top N call sites
        double freq = 0.0;
        int i = 0;
        boolean noInlining = true;
        Set<IRScope> inlinedScopes = new HashSet<IRScope>();
        for (IRCallSite ircs: callSites) {
            double contrib = (ircs.count*100.0)/total;

            // 1% is arbitrary
            if (contrib < 1.0) break;

            i++;
            freq += contrib;

            // This check is arbitrary
            if (i == 100 || freq > 99.0) break;

            System.out.println("Considering: " + ircs.call + " with id: " + ircs.call.callSiteId +
            " in scope " + ircs.s + " with count " + ircs.count + "; contrib " + contrib + "; freq: " + freq);

            // Now inline here!
            CallBase call = ircs.call;

            IRScope hs = ircs.s;
            boolean isHotClosure = hs instanceof IRClosure;
            IRScope hc = isHotClosure ? hs : null;
            hs = isHotClosure ? hs.getLexicalParent() : hs;

            IRScope tgtMethod = ircs.tgtM.getIRScope();

            Instr[] instrs = tgtMethod.getInterpreterContext().getInstructions();
            // Dont inline large methods -- 500 is arbitrary
            // Can be null if a previously inlined method hasn't been rebuilt
            if ((instrs == null) || instrs.length > 500) {
                if (instrs == null) System.out.println("no instrs!");
                else System.out.println("large method with " + instrs.length + " instrs. skipping!");
                continue;
            }

            RubyModule implClass = ircs.tgtM.getImplementationClass();
            int classToken = implClass.getGeneration();
            String n = tgtMethod.getId();
            boolean inlineCall = true;
            if (isHotClosure) {
                Operand clArg = call.getClosureArg(null);
                inlineCall = (clArg instanceof WrappedIRClosure) && (((WrappedIRClosure)clArg).getClosure() == hc);
            }
/*
            if (inlineCall) {
                noInlining = false;
                long start = new java.util.Date().getTime();
                hs.inlineMethod(tgtMethod, implClass, classToken, null, call, !inlinedScopes.contains(hs));
                inlinedScopes.add(hs);
                long end = new java.util.Date().getTime();
                // System.out.println("Inlined " + tgtMethod + " in " + hs +
                //     " @ instr " + call + " in time (ms): "
                //     + (end-start) + " # instrs: " + instrs.length);

                inlineCount++;
            } else {
                //System.out.println("--no inlining--");
            }
*/
        }

        for (IRScope x: inlinedScopes) {
            // Update version count for inlined scopes
            scopeVersionMap.put(x, versionCount);
            // System.out.println("Updating version of " + x + " to " + versionCount);
            //System.out.println("--- pre-inline-instrs ---");
            //System.out.println(x.getCFG().toStringInstrs());
            //System.out.println("--- post-inline-instrs ---");
            //System.out.println(x.getCFG().toStringInstrs());
        }

        // Reset
        codeModificationsCount = 0;
        callProfile = new HashMap<Long, CallSiteProfile>();

        // Every 1M thread polls, discard stats
        if (globalThreadPollCount % 1000000 == 0)  {
            globalThreadPollCount = 0;
        }
    }

    private static void outputProfileStats() {
        ArrayList<IRScope> scopes = new ArrayList<IRScope>(scopeThreadPollCounts.keySet());
        // ENEBO: Removes when getThreadPollInstrsCount()...when profiling revisited we need to find good metric.
        //   I am unsure total count of threadpoll has much correlation to # visited since one may be on back end of
        //   a single loop and all the rest are never actually visited.
        /*
        Collections.sort(scopes, new java.util.Comparator<IRScope> () {
            @Override
            public int compare(IRScope a, IRScope b) {
                // In non-methods and non-closures, we may not have any thread poll instrs.
                int aden = a.getThreadPollInstrsCount();
                if (aden == 0) aden = 1;
                int bden = b.getThreadPollInstrsCount();
                if (bden == 0) bden = 1;

                // Use estimated instr count to order scopes -- rather than raw thread-poll count
                float aCount = scopeThreadPollCounts.get(a).count * (1.0f * a.getInterpreterContext().getInstructions().length/aden);
                float bCount = scopeThreadPollCounts.get(b).count * (1.0f * b.getInterpreterContext().getInstructions().length/bden);
                if (aCount == bCount) return 0;
                return (aCount < bCount) ? 1 : -1;
            }
        });*/


        /*
        LOG.info("------------------------");
        LOG.info("Stats after " + globalThreadPollCount + " thread polls:");
        LOG.info("------------------------");
        LOG.info("# instructions: " + interpInstrsCount);
        LOG.info("# code modifications in this period : " + codeModificationsCount);
        LOG.info("------------------------");
        */
        int i = 0;
        float f1 = 0.0f;
        for (IRScope s: scopes) {
            long n = scopeThreadPollCounts.get(s).count;
            float p1 =  ((n*1000)/globalThreadPollCount)/10.0f;
            String msg = i + ". " + s + " [file:" + s.getFile() + ":" + s.getLine() + "] = " + n + "; (" + p1 + "%)";
            if (s instanceof IRClosure) {
                IRMethod m = s.getNearestMethod();
                //if (m != null) LOG.info(msg + " -- nearest enclosing method: " + m);
                //else LOG.info(msg + " -- no enclosing method --");
            } else {
                //LOG.info(msg);
            }

            i++;
            f1 += p1;

            // Top 20 or those that account for 95% of thread poll events.
            if (i == 20 || f1 >= 95.0) break;
        }

        // reset code modification counter
        codeModificationsCount = 0;

        // Every 1M thread polls, discard stats by reallocating the thread-poll count map
         if (globalThreadPollCount % 1000000 == 0)  {
            //System.out.println("---- resetting thread-poll counters ----");
            scopeThreadPollCounts = new HashMap<IRScope, Counter>();
            globalThreadPollCount = 0;
        }
    }

    public static Integer initProfiling(IRScope scope) {
        if (scope == null) return null;

        /* SSS: Not being used currently
        tpCount = scopeThreadPollCounts.get(scope);
        if (tpCount == null) {
            tpCount = new Counter();
            scopeThreadPollCounts.put(scope, tpCount);
        }
        */

        Integer scopeVersion = scopeVersionMap.get(scope);
        if (scopeVersion == null) {
            scopeVersionMap.put(scope, versionCount);
            scopeVersion = new Integer(versionCount);
        }

        if (callerSite.call != null) {
            Long id = callerSite.call.callSiteId;
            CallSiteProfile csp = callProfile.get(id);
            if (csp == null) {
                csp = new CallSiteProfile(callerSite);
                callProfile.put(id, csp);
            }

            Counter csCount = csp.counters.get(scope);
            if (csCount == null) {
                csCount = new Counter();
                csp.counters.put(scope, csCount);
            }
            csCount.count++;
        }

        return scopeVersion;
    }

    public static void updateCallSite(Instr instr, IRScope scope, Integer scopeVersion) {
        if (scope == null) return;

        if (instr instanceof CallBase) {
            callerSite.s = scope;
            callerSite.v = scopeVersion;
            callerSite.call = (CallBase)instr;
        }
    }

    public static void clockTick() {
        // tpCount.count++; // SSS: Not being used currently
        globalThreadPollCount++;

        // 20K is arbitrary
        // Every 20K profile counts, spit out profile stats
        if (globalThreadPollCount % 20000 == 0) {
            analyzeProfile();
            // outputProfileStats();
        }
    }

    public static void instrTick(Operation operation) {
        if (operation.modifiesCode()) codeModificationsCount++;
        /*
        Counter cnt = opStats.get(operation);
        if (cnt == null) {
            cnt = new Counter();
            opStats.put(operation, cnt);
        }
        cnt.count++;
        */
    }
}