/*
 * Copyright Terracotta, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.ehcache.impl.serialization;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.ObjectStreamClass;
import java.io.OutputStream;
import java.lang.ref.WeakReference;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Map.Entry;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;

import org.ehcache.spi.persistence.StateHolder;
import org.ehcache.spi.persistence.StateRepository;
import org.ehcache.spi.serialization.SerializerException;
import org.ehcache.core.util.ByteBufferInputStream;
import org.ehcache.spi.serialization.Serializer;
import org.ehcache.spi.serialization.StatefulSerializer;

import static java.lang.Math.max;

A trivially compressed Java serialization based serializer.

Class descriptors in the resultant bytes are encoded as integers. Mappings between the integer representation and the ObjectStreamClass, and the Class and the integer representation are stored in a single on-heap map.

/** * A trivially compressed Java serialization based serializer. * <p> * Class descriptors in the resultant bytes are encoded as integers. Mappings * between the integer representation and the {@link ObjectStreamClass}, and the * {@code Class} and the integer representation are stored in a single on-heap * map. */
public class CompactJavaSerializer<T> implements StatefulSerializer<T> { private volatile StateHolder<Integer, ObjectStreamClass> persistentState; private final ConcurrentMap<Integer, ObjectStreamClass> readLookupCache = new ConcurrentHashMap<>(); private final ConcurrentMap<SerializableDataKey, Integer> writeLookupCache = new ConcurrentHashMap<>(); private final Lock lock = new ReentrantLock(); private int nextStreamIndex = 0; private boolean potentiallyInconsistent; private final transient ClassLoader loader;
Constructor to enable this serializer as a transient one.
Params:
  • loader – the classloader to use
See Also:
/** * Constructor to enable this serializer as a transient one. * * @param loader the classloader to use * * @see Serializer */
public CompactJavaSerializer(ClassLoader loader) { this.loader = loader; } @SuppressWarnings("unchecked") public static <T> Class<? extends Serializer<T>> asTypedSerializer() { return (Class) CompactJavaSerializer.class; } @Override public void init(final StateRepository stateRepository) { this.persistentState = stateRepository.getPersistentStateHolder("CompactJavaSerializer-ObjectStreamClassIndex", Integer.class, ObjectStreamClass.class, c -> true, null); refreshMappingsFromStateRepository(); }
{@inheritDoc}
/** * {@inheritDoc} */
@Override public ByteBuffer serialize(T object) throws SerializerException { try { ByteArrayOutputStream bout = new ByteArrayOutputStream(); try (ObjectOutputStream oout = getObjectOutputStream(bout)) { oout.writeObject(object); } return ByteBuffer.wrap(bout.toByteArray()); } catch (IOException e) { throw new SerializerException(e); } }
{@inheritDoc}
/** * {@inheritDoc} */
@Override public T read(ByteBuffer binary) throws ClassNotFoundException, SerializerException { try { try (ObjectInputStream oin = getObjectInputStream(new ByteBufferInputStream(binary))) { @SuppressWarnings("unchecked") T value = (T) oin.readObject(); return value; } } catch (IOException e) { throw new SerializerException(e); } } private ObjectOutputStream getObjectOutputStream(OutputStream out) throws IOException { return new OOS(out); } private ObjectInputStream getObjectInputStream(InputStream input) throws IOException { return new OIS(input, loader); }
{@inheritDoc}
/** * {@inheritDoc} */
@Override public boolean equals(T object, ByteBuffer binary) throws ClassNotFoundException, SerializerException { return object.equals(read(binary)); } private int getOrAddMapping(ObjectStreamClass desc) { SerializableDataKey probe = new SerializableDataKey(desc, false); Integer rep = writeLookupCache.get(probe); if (rep == null) { return addMappingUnderLock(desc, probe); } else { return rep; } } private int addMappingUnderLock(ObjectStreamClass desc, SerializableDataKey probe) { lock.lock(); try { if (potentiallyInconsistent) { refreshMappingsFromStateRepository(); potentiallyInconsistent = false; } while (true) { Integer rep = writeLookupCache.get(probe); if (rep != null) { return rep; } rep = nextStreamIndex++; try { ObjectStreamClass disconnected = disconnect(desc); ObjectStreamClass existingOsc = persistentState.putIfAbsent(rep, disconnected); if (existingOsc == null) { cacheMapping(rep, disconnected); return rep; } else { cacheMapping(rep, disconnect(existingOsc)); } } catch (Throwable t) { potentiallyInconsistent = true; throw t; } } } finally { lock.unlock(); } } private void refreshMappingsFromStateRepository() { int highestIndex = -1; for (Entry<Integer, ObjectStreamClass> entry : persistentState.entrySet()) { Integer index = entry.getKey(); cacheMapping(entry.getKey(), disconnect(entry.getValue())); highestIndex = max(highestIndex, index); } nextStreamIndex = highestIndex + 1; } private void cacheMapping(Integer index, ObjectStreamClass disconnectedOsc) { readLookupCache.merge(index, disconnectedOsc, (existing, update) -> { if (equals(existing, update)) { return existing; } else { throw new AssertionError("Corrupted data:\n" + "State Repository: " + persistentState + "\n" + "Local Write Lookup: " + writeLookupCache + "\n" + "Local Read Lookup: " + readLookupCache); } }); writeLookupCache.merge(new SerializableDataKey(disconnectedOsc, true), index, (existing, update) -> { if (existing.equals(update)) { return existing; } else { throw new AssertionError("Corrupted data:\n" + "State Repository: " + persistentState + "\n" + "Local Write Lookup: " + writeLookupCache + "\n" + "Local Read Lookup: " + readLookupCache); } }); } class OOS extends ObjectOutputStream { OOS(OutputStream out) throws IOException { super(out); } @Override protected void writeClassDescriptor(final ObjectStreamClass desc) throws IOException { writeInt(getOrAddMapping(desc)); } } class OIS extends ObjectInputStream { private final ClassLoader loader; OIS(InputStream in, ClassLoader loader) throws IOException { super(in); this.loader = loader; } @Override protected ObjectStreamClass readClassDescriptor() throws IOException { int key = readInt(); ObjectStreamClass objectStreamClass = readLookupCache.get(key); if (objectStreamClass == null) { objectStreamClass = persistentState.get(key); cacheMapping(key, disconnect(objectStreamClass)); } return objectStreamClass; } @Override protected Class<?> resolveClass(ObjectStreamClass desc) throws IOException, ClassNotFoundException { try { final ClassLoader cl = loader == null ? Thread.currentThread().getContextClassLoader() : loader; if (cl == null) { return super.resolveClass(desc); } else { try { return Class.forName(desc.getName(), false, cl); } catch (ClassNotFoundException e) { return super.resolveClass(desc); } } } catch (SecurityException ex) { return super.resolveClass(desc); } } } private static class SerializableDataKey { private final ObjectStreamClass osc; private final int hashCode; private transient WeakReference<Class<?>> klazz; SerializableDataKey(ObjectStreamClass desc, boolean store) { Class<?> forClass = desc.forClass(); if (forClass != null) { if (store) { throw new AssertionError("Must not store ObjectStreamClass instances with strong references to classes"); } else if (ObjectStreamClass.lookup(forClass) == desc) { this.klazz = new WeakReference<>(forClass); } } this.hashCode = (3 * desc.getName().hashCode()) ^ (7 * (int) (desc.getSerialVersionUID() >>> 32)) ^ (11 * (int) desc.getSerialVersionUID()); this.osc = desc; } @Override public boolean equals(Object o) { if (o instanceof SerializableDataKey) { return CompactJavaSerializer.equals(this, (SerializableDataKey) o); } else { return false; } } @Override public int hashCode() { return hashCode; } Class<?> forClass() { if (klazz == null) { return null; } else { return klazz.get(); } } public void setClass(Class<?> clazz) { klazz = new WeakReference<>(clazz); } ObjectStreamClass getObjectStreamClass() { return osc; } } private static boolean equals(SerializableDataKey k1, SerializableDataKey k2) { Class<?> k1Clazz = k1.forClass(); Class<?> k2Clazz = k2.forClass(); if (k1Clazz != null && k2Clazz != null) { return k1Clazz == k2Clazz; } else if (CompactJavaSerializer.equals(k1.getObjectStreamClass(), k2.getObjectStreamClass())) { if (k1Clazz != null) { k2.setClass(k1Clazz); } else if (k2Clazz != null) { k1.setClass(k2Clazz); } return true; } else { return false; } } private static boolean equals(ObjectStreamClass osc1, ObjectStreamClass osc2) { if (osc1 == osc2) { return true; } else if (osc1.getName().equals(osc2.getName()) && osc1.getSerialVersionUID() == osc2.getSerialVersionUID() && osc1.getFields().length == osc2.getFields().length) { try { return Arrays.equals(getSerializedForm(osc1), getSerializedForm(osc2)); } catch (IOException e) { throw new AssertionError(e); } } else { return false; } } private static ObjectStreamClass disconnect(ObjectStreamClass desc) { try { ObjectInputStream oin = new ObjectInputStream(new ByteArrayInputStream(getSerializedForm(desc))) { @Override protected Class<?> resolveClass(ObjectStreamClass osc) { //Our stored OSC instances should not reference classes - doing so could cause perm-gen leaks return null; } }; return (ObjectStreamClass) oin.readObject(); } catch (ClassNotFoundException | IOException e) { throw new AssertionError(e); } } private static byte[] getSerializedForm(ObjectStreamClass desc) throws IOException { ByteArrayOutputStream bout = new ByteArrayOutputStream(); try { try (ObjectOutputStream oout = new ObjectOutputStream(bout)) { oout.writeObject(desc); } } finally { bout.close(); } return bout.toByteArray(); } }