/*
 * Copyright (c) 2016, Oracle and/or its affiliates. All rights reserved.
 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
 *
 * This code is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License version 2 only, as
 * published by the Free Software Foundation.
 *
 * This code is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * version 2 for more details (a copy is included in the LICENSE file that
 * accompanied this code).
 *
 * You should have received a copy of the GNU General Public License version
 * 2 along with this work; if not, write to the Free Software Foundation,
 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
 * or visit www.oracle.com if you need additional information or have any
 * questions.
 */
package org.graalvm.compiler.replacements.test;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameter;
import org.junit.runners.Parameterized.Parameters;

import org.graalvm.compiler.core.common.type.IntegerStamp;
import org.graalvm.compiler.core.common.type.StampFactory;
import org.graalvm.compiler.core.test.GraalCompilerTest;
import org.graalvm.compiler.nodes.ConstantNode;
import org.graalvm.compiler.nodes.ParameterNode;
import org.graalvm.compiler.nodes.PiNode;
import org.graalvm.compiler.nodes.StructuredGraph;
import org.graalvm.compiler.nodes.ValueNode;
import org.graalvm.compiler.nodes.StructuredGraph.AllowAssumptions;
import org.graalvm.compiler.nodes.calc.MulNode;
import org.graalvm.compiler.nodes.java.StoreFieldNode;
import org.graalvm.compiler.phases.common.CanonicalizerPhase;
import org.graalvm.compiler.phases.tiers.HighTierContext;
import org.graalvm.compiler.replacements.nodes.arithmetic.IntegerMulExactNode;

@RunWith(Parameterized.class)
public class IntegerMulExactFoldTest extends GraalCompilerTest {

    /** Static field written by {@link #snippetInt}; the test later inspects the resulting field store. */
    public static int SideEffectI;
    /** Static field written by {@link #snippetLong}; the test later inspects the resulting field store. */
    public static long SideEffectL;

    /** 32-bit snippet: stores {@code Math.multiplyExact(a, b)} into {@link #SideEffectI}. */
    public static void snippetInt(int a, int b) {
        SideEffectI = Math.multiplyExact(a, b);
    }

    /** 64-bit snippet: stores {@code Math.multiplyExact(a, b)} into {@link #SideEffectL}. */
    public static void snippetLong(long a, long b) {
        SideEffectL = Math.multiplyExact(a, b);
    }

    /**
     * Parses {@code snippet} eagerly (without assumptions) and runs one canonicalization round on it.
     *
     * @param snippet name of the snippet method to parse
     * @return the canonicalized graph
     */
    private StructuredGraph prepareGraph(String snippet) {
        StructuredGraph graph = parseEager(snippet, AllowAssumptions.NO);
        HighTierContext context = getDefaultHighTierContext();
        new CanonicalizerPhase().apply(graph, context);
        return graph;
    }

    /** Inclusive lower bound of the first operand's range. */
    @Parameter(0) public long lowerBound1;
    /** Inclusive upper bound of the first operand's range. */
    @Parameter(1) public long upperBound1;
    /** Inclusive lower bound of the second operand's range. */
    @Parameter(2) public long lowerBound2;
    /** Inclusive upper bound of the second operand's range. */
    @Parameter(3) public long upperBound2;
    /** Operand width in bits; {@link #tryFold} supports 32 and 64. */
    @Parameter(4) public int bits;

    /**
     * Pins the operands of the exact multiplication to the parameterized ranges, canonicalizes,
     * and then exhaustively checks the two operand ranges against the result.
     *
     * <p>Two properties are checked for every pair of values. First, if the product overflows,
     * the exact multiplication must not have been replaced. Second, every non-overflowing
     * product must lie within the stamp of the node that computes the result.
     */
    @Test
    public void tryFold() {
        // A JUnit check rather than a Java assert, so it also runs without -ea.
        Assert.assertTrue("unsupported bit width: " + bits, bits == 32 || bits == 64);

        IntegerStamp a = StampFactory.forInteger(bits, lowerBound1, upperBound1);
        IntegerStamp b = StampFactory.forInteger(bits, lowerBound2, upperBound2);

        // Prepare the graph once for the given stamps. If the canonicalizer decides the product
        // cannot overflow, it replaces the exact mul with a normal mul.
        StructuredGraph g = prepareGraph(bits == 32 ? "snippetInt" : "snippetLong");
        List<ParameterNode> params = g.getNodes(ParameterNode.TYPE).snapshot();
        params.get(0).replaceAtMatchingUsages(g.addOrUnique(new PiNode(params.get(0), a)), x -> x instanceof IntegerMulExactNode);
        params.get(1).replaceAtMatchingUsages(g.addOrUnique(new PiNode(params.get(1), b)), x -> x instanceof IntegerMulExactNode);
        new CanonicalizerPhase().apply(g, getDefaultHighTierContext());
        boolean optimized = g.getNodes().filter(IntegerMulExactNode.class).count() == 0;
        ValueNode leftOverMul = optimized ? g.getNodes().filter(MulNode.class).first() : g.getNodes().filter(IntegerMulExactNode.class).first();
        new CanonicalizerPhase().apply(g, getDefaultHighTierContext());
        if (leftOverMul == null) {
            // Neither a mul exact nor a mul node is left, so the result may have folded to a
            // constant that feeds the field store.
            StoreFieldNode store = g.getNodes().filter(StoreFieldNode.class).first();
            Assert.assertNotNull("field store of the multiplication result is missing", store);
            leftOverMul = store.inputs().filter(ConstantNode.class).first();
        }
        if (leftOverMul == null) {
            // Even the mul got canonicalized away, so we may end up with one of the pinned
            // operands (e.g. multiplication by one).
            leftOverMul = g.getNodes().filter(PiNode.class).first();
        }
        Assert.assertNotNull("could not find the node computing the multiplication result", leftOverMul);
        IntegerStamp resultStamp = (IntegerStamp) leftOverMul.stamp();

        // Now check, for all values in the operand stamps, whether their products overflow.
        for (long l1 = lowerBound1; l1 <= upperBound1; l1++) {
            for (long l2 = lowerBound2; l2 <= upperBound2; l2++) {
                try {
                    long product = mulExact(l1, l2, bits);
                    // Build the message only on failure; these loops can run tens of thousands of times.
                    if (!resultStamp.contains(product)) {
                        Assert.fail("result stamp " + resultStamp + " does not contain " + l1 + " * " + l2 + " = " + product);
                    }
                } catch (ArithmeticException overflow) {
                    if (optimized) {
                        Assert.fail(l1 + " * " + l2 + " overflows " + bits + " bits, but the exact multiplication was replaced");
                    }
                }
                if (l2 == Long.MAX_VALUE) {
                    // do not want to overflow the loop counter
                    break;
                }
            }
            if (l1 == Long.MAX_VALUE) {
                // do not want to overflow the loop counter
                break;
            }
        }
    }

    /**
     * Multiplies {@code x} and {@code y} as {@code bits}-wide signed integers.
     *
     * <p>For 32 and 64 bits this delegates to {@link Math#multiplyExact}. For 8 and 16 bits, it
     * checks whether the 64-bit product fits the narrower type. The 8/16-bit paths assume the
     * inputs already fit the narrow type, so the 64-bit product itself cannot overflow.
     *
     * @return the exact product
     * @throws ArithmeticException if the product does not fit in {@code bits} bits
     */
    private static long mulExact(long x, long y, int bits) {
        long r = x * y;
        if (bits == 8) {
            if ((byte) r != r) {
                throw new ArithmeticException("overflow");
            }
        } else if (bits == 16) {
            if ((short) r != r) {
                throw new ArithmeticException("overflow");
            }
        } else if (bits == 32) {
            return Math.multiplyExact((int) x, (int) y);
        } else {
            return Math.multiplyExact(x, y);
        }
        return r;
    }

    /**
     * Test cases: {@code {lowerBound1, upperBound1, lowerBound2, upperBound2, bits}}.
     *
     * <p>The display name uses {@code {4}} so each case reports its actual bit width.
     */
    @Parameters(name = "a[{0} - {1}] b[{2} - {3}] bits={4}")
    public static Collection<Object[]> data() {
        ArrayList<Object[]> tests = new ArrayList<>();

        // zero related
        addTest(tests, -2, 2, 3, 3, 32);
        addTest(tests, 0, 0, 1, 1, 32);
        addTest(tests, 1, 1, 0, 0, 32);
        addTest(tests, -1, 1, 0, 1, 32);
        addTest(tests, -1, 1, 1, 1, 32);
        addTest(tests, -1, 1, -1, 1, 32);

        addTest(tests, -2, 2, 3, 3, 64);
        addTest(tests, 0, 0, 1, 1, 64);
        addTest(tests, 1, 1, 0, 0, 64);
        addTest(tests, -1, 1, 0, 1, 64);
        addTest(tests, -1, 1, 1, 1, 64);
        addTest(tests, -1, 1, -1, 1, 64);

        // bounds
        addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xFF, Integer.MAX_VALUE - 0xFF,
                        Integer.MAX_VALUE, 32);
        addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xFFF, -1, -1, 32);
        addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xFF, Integer.MAX_VALUE - 0xFF,
                        Integer.MAX_VALUE, 64);
        addTest(tests, Integer.MIN_VALUE, Integer.MIN_VALUE + 0xFFF, -1, -1, 64);
        addTest(tests, Long.MIN_VALUE, Long.MIN_VALUE + 0xFFF, -1, -1, 64);

        // constants
        addTest(tests, 2, 2, 2, 2, 32);
        addTest(tests, 1, 1, 2, 2, 32);
        addTest(tests, 2, 2, 4, 4, 32);
        addTest(tests, 3, 3, 3, 3, 32);
        addTest(tests, -4, -4, 3, 3, 32);
        addTest(tests, -4, -4, -3, -3, 32);
        addTest(tests, 4, 4, -3, -3, 32);

        addTest(tests, 2, 2, 2, 2, 64);
        addTest(tests, 1, 1, 2, 2, 64);
        addTest(tests, 3, 3, 3, 3, 64);

        addTest(tests, Long.MAX_VALUE, Long.MAX_VALUE, 1, 1, 64);
        addTest(tests, Long.MAX_VALUE, Long.MAX_VALUE, -1, -1, 64);
        addTest(tests, Long.MIN_VALUE, Long.MIN_VALUE, -1, -1, 64);
        addTest(tests, Long.MIN_VALUE, Long.MIN_VALUE, 1, 1, 64);

        return tests;
    }

    /** Appends one parameter row in the order expected by the {@code @Parameter} fields. */
    private static void addTest(ArrayList<Object[]> tests, long lowerBound1, long upperBound1, long lowerBound2, long upperBound2, int bits) {
        tests.add(new Object[]{lowerBound1, upperBound1, lowerBound2, upperBound2, bits});
    }

}