/*
 **** BEGIN LICENSE BLOCK *****
 * Version: EPL 2.0/GPL 2.0/LGPL 2.1
 *
 * The contents of this file are subject to the Eclipse Public
 * License Version 2.0 (the "License"); you may not use this file
 * except in compliance with the License. You may obtain a copy of
 * the License at http://www.eclipse.org/legal/epl-v20.html
 *
 * Software distributed under the License is distributed on an "AS
 * IS" basis, WITHOUT WARRANTY OF ANY KIND, either express or
 * implied. See the License for the specific language governing
 * rights and limitations under the License.
 *
 * Alternatively, the contents of this file may be used under the terms of
 * either of the GNU General Public License Version 2 or later (the "GPL"),
 * or the GNU Lesser General Public License Version 2.1 or later (the "LGPL"),
 * in which case the provisions of the GPL or the LGPL are applicable instead
 * of those above. If you wish to allow use of your version of this file only
 * under the terms of either the GPL or the LGPL, and not to allow others to
 * use your version of this file under the terms of the EPL, indicate your
 * decision by deleting the provisions above and replace them with the notice
 * and other provisions required by the GPL or the LGPL. If you do not delete
 * the provisions above, a recipient may use your version of this file under
 * the terms of any one of the EPL, the GPL or the LGPL.
 ***** END LICENSE BLOCK *****/

package org.jruby.ext.zlib;

import com.jcraft.jzlib.JZlib;
import org.jruby.Ruby;
import org.jruby.RubyClass;
import org.jruby.RubyNumeric;
import org.jruby.RubyString;
import org.jruby.anno.JRubyClass;
import org.jruby.anno.JRubyMethod;
import org.jruby.runtime.Block;
import org.jruby.runtime.ObjectAllocator;
import org.jruby.runtime.ThreadContext;
import static org.jruby.runtime.Visibility.PRIVATE;
import org.jruby.runtime.builtin.IRubyObject;
import org.jruby.util.ByteList;

 @JRubyClass(name = "Zlib::Inflate", parent = "Zlib::ZStream")
public class JZlibInflate extends ZStream {

    public static final int BASE_SIZE = 100;
    private int windowBits;
    private byte[] collected;
    private int collectedIdx;
    private ByteList input;
    private com.jcraft.jzlib.Inflater flater = null;
    protected static final ObjectAllocator INFLATE_ALLOCATOR = new ObjectAllocator() {

        public IRubyObject allocate(Ruby runtime, RubyClass klass) {
            return new JZlibInflate(runtime, klass);
        }
    };

    public JZlibInflate(Ruby runtime, RubyClass type) {
        super(runtime, type);
    }

    @JRubyMethod(name = "inflate", required = 1, meta = true)
    public static IRubyObject s_inflate(ThreadContext context, IRubyObject recv, IRubyObject string) {
        RubyClass klass = (RubyClass)(recv.isClass() ? recv : context.runtime.getClassFromPath("Zlib::Inflate"));
        JZlibInflate inflate = (JZlibInflate) klass.allocate();
        inflate.init(JZlib.DEF_WBITS);

        IRubyObject result;
        try {
            inflate.append(string.convertToString().getByteList());
        } finally {
            result = inflate.finish(context, Block.NULL_BLOCK);
            inflate.close();
        }
        return result;
    }

    @JRubyMethod(name = "initialize", optional = 1, visibility = PRIVATE)
    public IRubyObject _initialize(IRubyObject[] args) {
        windowBits = JZlib.DEF_WBITS;

        if (args.length > 0 && !args[0].isNil()) {
            windowBits = RubyNumeric.fix2int(args[0]);
            checkWindowBits(getRuntime(), windowBits, true);
        }

        init(windowBits);
        return this;
    }

    private void init(int windowBits) {
        flater = new com.jcraft.jzlib.Inflater();
        flater.init(windowBits);
        collected = new byte[BASE_SIZE];
        collectedIdx = 0;
        input = new ByteList();
    }

    @Override
    @JRubyMethod(name = "flush_next_out")
    public IRubyObject flush_next_out(ThreadContext context, Block block) {
        return flushOutput(context, block);
    }

    private IRubyObject flushOutput(ThreadContext context, Block block) {
        if (collectedIdx > 0) {
            RubyString res = RubyString.newString(context.runtime, collected, 0, collectedIdx);
            collectedIdx = 0;
            flater.setOutput(collected);

            if (block.isGiven()) {
                block.yield(context, res);
                return context.nil;
            }
            return res;
        }
        return RubyString.newEmptyString(context.runtime);
    }

    @JRubyMethod(name = "<<", required = 1)
    public IRubyObject append(ThreadContext context, IRubyObject arg) {
        checkClosed();
        if (arg.isNil()) {
            run(true);
        } else {
            append(arg.convertToString().getByteList());
        }
        return this;
    }

    public void append(ByteList obj) {
        if (!internalFinished()) {
            flater.setInput(obj.bytes(), true);
        } else {
            input.append(obj);
        }
        run(false);
    }

    @JRubyMethod(name = "sync_point?")
    public IRubyObject sync_point_p() {
        return sync_point();
    }

    public IRubyObject sync_point() {
        int ret = flater.syncPoint();
        switch (ret) {
            case 1:
                return getRuntime().getTrue();
            case com.jcraft.jzlib.JZlib.Z_DATA_ERROR:
                throw RubyZlib.newStreamError(getRuntime(), "stream error");
            default:
                return getRuntime().getFalse();
        }
    }

    @JRubyMethod(name = "set_dictionary", required = 1)
    public IRubyObject set_dictionary(ThreadContext context, IRubyObject arg) {
        try {
            return set_dictionary(arg);
        } catch (IllegalArgumentException iae) {
            throw RubyZlib.newStreamError(context.getRuntime(), "stream error: " + iae.getMessage());
        }
    }

    private IRubyObject set_dictionary(IRubyObject str) {
        byte[] tmp = str.convertToString().getBytes();
        int ret = flater.setDictionary(tmp, tmp.length);
        switch (ret) {
            case com.jcraft.jzlib.JZlib.Z_STREAM_ERROR:
                throw RubyZlib.newStreamError(getRuntime(), "stream error");
            case com.jcraft.jzlib.JZlib.Z_DATA_ERROR:
                throw RubyZlib.newDataError(getRuntime(), "wrong dictionary");
            default:
        }
        run(false);
        return str;
    }

    @JRubyMethod(name = "inflate", required = 1)
    public IRubyObject inflate(ThreadContext context, IRubyObject string, Block block) {
        ByteList data = null;
        if (!string.isNil()) {
            data = string.convertToString().getByteList();
        }
        return inflate(context, data, block);
    }

    public IRubyObject inflate(ThreadContext context, ByteList str, Block block) {
        if (null == str) {
            return internalFinish(block);
        } else {
            append(str);
            return flushOutput(context, block);
        }
    }

    @JRubyMethod(name = "sync", required = 1)
    public IRubyObject sync(ThreadContext context, IRubyObject string) {
        if (flater.avail_in > 0) {
            switch (flater.sync()) {
                case com.jcraft.jzlib.JZlib.Z_OK:
                    flater.setInput(string.convertToString().getByteList().bytes(),
                            true);
                    return getRuntime().getTrue();
                case com.jcraft.jzlib.JZlib.Z_DATA_ERROR:
                    break;
                default:
                    throw RubyZlib.newStreamError(getRuntime(), "stream error");
            }
        }
        if (string.convertToString().getByteList().length() <= 0) {
            return getRuntime().getFalse();
        }
        flater.setInput(string.convertToString().getByteList().bytes(), true);
        switch (flater.sync()) {
            case com.jcraft.jzlib.JZlib.Z_OK:
                return getRuntime().getTrue();
            case com.jcraft.jzlib.JZlib.Z_DATA_ERROR:
                return getRuntime().getFalse();
            default:
                throw RubyZlib.newStreamError(getRuntime(), "stream error");
        }
    }

    private void run(boolean finish) {
        int resultLength = -1;
        Ruby runtime = getRuntime();

        while (!internalFinished() && resultLength != 0) {
            // MRI behavior
            boolean needsInput = flater.avail_in < 0;
            if (finish && needsInput) {
                throw RubyZlib.newBufError(runtime, "buffer error");
            }

            flater.setOutput(collected, collectedIdx, collected.length - collectedIdx);
            int ret = flater.inflate(com.jcraft.jzlib.JZlib.Z_NO_FLUSH);
            resultLength = flater.next_out_index - collectedIdx;
            collectedIdx = flater.next_out_index;
            switch (ret) {
                case com.jcraft.jzlib.JZlib.Z_DATA_ERROR:
                    /*
                     * resultLength = flater.next_out_index; if(resultLength>0){
                     * // error has been occurred, // but some data has been
                     * inflated successfully. collected.append(outp, 0,
                     * resultLength); }
                     */
                    throw RubyZlib.newDataError(runtime, flater.getMessage());
                case com.jcraft.jzlib.JZlib.Z_NEED_DICT:
                    throw RubyZlib.newDictError(runtime, "need dictionary");
                case com.jcraft.jzlib.JZlib.Z_STREAM_END:
                    if (flater.avail_in > 0) {
                        // MRI behavior: pass-through
                        input.append(flater.next_in,
                                flater.next_in_index, flater.avail_in);
                        flater.setInput("".getBytes());
                    }
                case com.jcraft.jzlib.JZlib.Z_OK:
                    resultLength = flater.next_out_index;
                    break;
                default:
                    resultLength = 0;
            }
            if (collected.length == collectedIdx && !internalFinished()) {
                byte[] tmp = new byte[collected.length * 3];
                System.arraycopy(collected, 0, tmp, 0, collected.length);
                collected = tmp;
            }
        }
        if (finish) {
            if (!internalFinished()) {
                int err = flater.inflate(com.jcraft.jzlib.JZlib.Z_FINISH);
                if (err != com.jcraft.jzlib.JZlib.Z_OK) {
                    throw RubyZlib.newBufError(getRuntime(), "buffer error");
                }
            }
        }
    }

    @Override
    protected int internalTotalIn() {
        return (int) flater.total_in;
    }

    @Override
    protected int internalTotalOut() {
        return (int) flater.total_out;
    }

    @Override
    protected boolean internalStreamEndP() {
        return flater.finished();
    }

    @Override
    protected void internalReset() {
        init(windowBits);
    }

    @Override
    protected boolean internalFinished() {
        return flater.finished();
    }

    @Override
    protected long internalAdler() {
        return flater.getAdler();
    }

    @Override
    protected IRubyObject internalFinish(Block block) {
        run(true);
        // MRI behavior: in finished mode, we work as pass-through
        if (internalFinished()) {
            if (input.getRealSize() > 0) {
                if (collected.length - collectedIdx < input.length()) {
                    byte[] tmp = new byte[collected.length + input.length()];
                    System.arraycopy(collected, 0, tmp, 0, collectedIdx);
                    collected = tmp;
                }
                System.arraycopy(input.getUnsafeBytes(), input.begin(), collected, collectedIdx, input.length());
                collectedIdx += input.length();
                resetBuffer(input);
            }
        }
        return flushOutput(getRuntime().getCurrentContext(), block);
    }

    @Override
    protected void internalClose() {
        flater.end();
    }

    @Override
    public IRubyObject avail_in() {
        return getRuntime().newFixnum(flater.avail_in);
    }

    private static void resetBuffer(ByteList l) {
        l.setBegin(0);
        l.setRealSize(0);
        l.invalidate();
    }
}