package org.apache.logging.log4j.util;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InvalidObjectException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.StreamCorruptedException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.Arrays;
import java.util.Collection;
import java.util.ConcurrentModificationException;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import org.apache.logging.log4j.status.StatusLogger;
/**
 * Array-based {@link IndexedStringMap} that keeps its keys in sorted order.
 * <p>
 * Keys and values live in two parallel arrays. Lookups use {@link Arrays#binarySearch}, and
 * insertions shift the tail of the arrays so the keys stay ordered. A {@code null} key is
 * permitted; when present it always occupies index 0.
 * </p>
 * <p>
 * This class performs no synchronization. Modifying the map from inside one of the
 * {@code forEach} callbacks throws {@link ConcurrentModificationException}. Once
 * {@link #freeze()} has been called, operations that would change the map throw
 * {@link UnsupportedOperationException}.
 * </p>
 * <p>
 * On serialization each value is written as its own byte array. On deserialization those byte
 * arrays are read back through a filtered stream, so the enclosing stream must be a
 * {@code FilteredObjectInputStream} or must support {@code ObjectInputFilter}. Values that
 * cannot be (de)serialized are logged to the status logger and replaced by {@code null}.
 * </p>
 */
public class SortedArrayStringMap implements IndexedStringMap {

    /** Capacity used by the no-argument constructor. */
    private static final int DEFAULT_INITIAL_CAPACITY = 4;

    private static final long serialVersionUID = -5748905872274478116L;

    /** Multiplier used when combining hash codes. */
    private static final int HASHVAL = 31;

    /** Copies each visited entry into the {@link StringMap} supplied as the iteration state. */
    private static final TriConsumer<String, Object, StringMap> PUT_ALL = new TriConsumer<String, Object, StringMap>() {
        @Override
        public void accept(final String key, final Object value, final StringMap contextData) {
            contextData.putValue(key, value);
        }
    };

    /**
     * Zero-length sentinel shared by {@link #keys} and {@link #values} until the table is inflated.
     * It is compared by identity ({@code keys == EMPTY}) to detect a map that has never been inflated.
     */
    private static final String[] EMPTY = {};

    private static final String FROZEN = "Frozen collection cannot be modified";

    /** Keys in sorted order; only indices {@code [0, size)} are meaningful. */
    private transient String[] keys = EMPTY;

    /** Values parallel to {@link #keys}. */
    private transient Object[] values = EMPTY;

    /** Number of mappings currently stored. */
    private transient int size;

    // Reflective handles for ObjectInputStream's filter accessors (added in Java 9) and for
    // DefaultObjectInputFilter.newInstance. Each handle is null when the method or class cannot be
    // found at runtime. readObject() refuses to run when the stream is not a
    // FilteredObjectInputStream and setObjectInputFilter is unavailable.
    private static final Method setObjectInputFilter;
    private static final Method getObjectInputFilter;
    private static final Method newObjectInputFilter;

    static {
        Method[] methods = ObjectInputStream.class.getMethods();
        Method setMethod = null;
        Method getMethod = null;
        for (final Method method : methods) {
            if (method.getName().equals("setObjectInputFilter")) {
                setMethod = method;
            } else if (method.getName().equals("getObjectInputFilter")) {
                getMethod = method;
            }
        }
        Method newMethod = null;
        try {
            if (setMethod != null) {
                final Class<?> clazz = Class.forName("org.apache.logging.log4j.util.internal.DefaultObjectInputFilter");
                methods = clazz.getMethods();
                for (final Method method : methods) {
                    if (method.getName().equals("newInstance") && Modifier.isStatic(method.getModifiers())) {
                        newMethod = method;
                        break;
                    }
                }
            }
        } catch (final ClassNotFoundException ignored) {
            // Leave newObjectInputFilter null when the filter factory class is not on the class path.
        }
        newObjectInputFilter = newMethod;
        setObjectInputFilter = setMethod;
        getObjectInputFilter = getMethod;
    }

    /**
     * Capacity of the backing arrays: an insertion grows them once {@code size} reaches this value.
     * Once the arrays are inflated, this must never exceed {@code keys.length}, otherwise
     * {@link #insertAt} would write past the end of the arrays.
     */
    private int threshold;

    /** Set by {@link #freeze()}; once true, operations that would change the map throw. */
    private boolean immutable;

    /** True while a {@code forEach} call is running; modifications then throw {@link ConcurrentModificationException}. */
    private transient boolean iterating;

    /** Creates an empty map with the default initial capacity. */
    public SortedArrayStringMap() {
        this(DEFAULT_INITIAL_CAPACITY);
    }

    /**
     * Creates an empty map. The backing arrays are allocated lazily on the first insertion.
     *
     * @param initialCapacity requested capacity, rounded up to a power of two (0 is treated as 1)
     * @throws IllegalArgumentException if {@code initialCapacity} is negative
     */
    public SortedArrayStringMap(final int initialCapacity) {
        if (initialCapacity < 0) {
            throw new IllegalArgumentException("Initial capacity must be at least zero but was " + initialCapacity);
        }
        threshold = ceilingNextPowerOfTwo(initialCapacity == 0 ? 1 : initialCapacity);
    }

    /**
     * Creates a map holding a copy of the mappings of {@code other}.
     *
     * @param other the map to copy; {@code null} yields an empty map with the default capacity
     */
    public SortedArrayStringMap(final ReadOnlyStringMap other) {
        if (other instanceof SortedArrayStringMap) {
            initFrom0((SortedArrayStringMap) other);
        } else if (other != null) {
            resize(ceilingNextPowerOfTwo(other.size()));
            other.forEach(PUT_ALL, this);
        } else {
            // Without this the threshold would stay 0, leaving no capacity to grow from.
            threshold = DEFAULT_INITIAL_CAPACITY;
        }
    }

    /**
     * Creates a map holding a copy of the entries of {@code map}.
     *
     * @param map the entries to copy
     * @throws NullPointerException if {@code map} is {@code null}
     */
    public SortedArrayStringMap(final Map<String, ?> map) {
        resize(ceilingNextPowerOfTwo(map.size()));
        for (final Map.Entry<String, ?> entry : map.entrySet()) {
            putValue(entry.getKey(), entry.getValue());
        }
    }

    /** @throws UnsupportedOperationException if this map has been frozen */
    private void assertNotFrozen() {
        if (immutable) {
            throw new UnsupportedOperationException(FROZEN);
        }
    }

    /** @throws ConcurrentModificationException if called while a {@code forEach} is in progress */
    private void assertNoConcurrentModification() {
        if (iterating) {
            throw new ConcurrentModificationException();
        }
    }

    /**
     * Removes all mappings, keeping the allocated arrays. Does nothing (and does not check the
     * frozen or iterating state) if the arrays were never inflated.
     */
    @Override
    public void clear() {
        if (keys == EMPTY) {
            return;
        }
        assertNotFrozen();
        assertNoConcurrentModification();
        Arrays.fill(keys, 0, size, null);
        Arrays.fill(values, 0, size, null);
        size = 0;
    }

    @Override
    public boolean containsKey(final String key) {
        return indexOfKey(key) >= 0;
    }

    /**
     * Returns a new {@link HashMap} with each value converted via {@link String#valueOf(Object)};
     * {@code null} values stay {@code null}.
     */
    @Override
    public Map<String, String> toMap() {
        final Map<String, String> result = new HashMap<>(size());
        for (int i = 0; i < size(); i++) {
            final Object value = getValueAt(i);
            result.put(getKeyAt(i), value == null ? null : String.valueOf(value));
        }
        return result;
    }

    /** Makes this map read-only; there is no way to unfreeze it. */
    @Override
    public void freeze() {
        immutable = true;
    }

    @Override
    public boolean isFrozen() {
        return immutable;
    }

    /**
     * Returns the value mapped to {@code key}, or {@code null} if there is no such mapping.
     * The result is cast unchecked to the caller's type.
     */
    @SuppressWarnings("unchecked")
    @Override
    public <V> V getValue(final String key) {
        final int index = indexOfKey(key);
        if (index < 0) {
            return null;
        }
        return (V) values[index];
    }

    @Override
    public boolean isEmpty() {
        return size == 0;
    }

    /**
     * Returns the index of {@code key}. If the key is absent, returns {@code ~insertionPoint},
     * which is always negative. A {@code null} key is handled separately, because it is stored at
     * index 0 outside the binary-searched range.
     */
    @Override
    public int indexOfKey(final String key) {
        if (keys == EMPTY) {
            return -1;
        }
        if (key == null) {
            return nullKeyIndex();
        }
        final int start = size > 0 && keys[0] == null ? 1 : 0;
        return Arrays.binarySearch(keys, start, size, key);
    }

    /** Returns 0 if a {@code null} key is present, otherwise {@code ~0} (insertion point 0). */
    private int nullKeyIndex() {
        return size > 0 && keys[0] == null ? 0 : ~0;
    }

    /**
     * Adds or replaces the mapping for {@code key}, inflating the arrays on first use.
     *
     * @throws UnsupportedOperationException if this map is frozen
     * @throws ConcurrentModificationException if called during {@code forEach}
     */
    @Override
    public void putValue(final String key, final Object value) {
        assertNotFrozen();
        assertNoConcurrentModification();
        if (keys == EMPTY) {
            inflateTable(threshold);
        }
        final int index = indexOfKey(key);
        if (index >= 0) {
            keys[index] = key;
            values[index] = value;
        } else {
            insertAt(~index, key, value);
        }
    }

    /** Inserts a new mapping at {@code index}, shifting later entries right by one and growing the arrays if needed. */
    private void insertAt(final int index, final String key, final Object value) {
        ensureCapacity();
        System.arraycopy(keys, index, keys, index + 1, size - index);
        System.arraycopy(values, index, values, index + 1, size - index);
        keys[index] = key;
        values[index] = value;
        size++;
    }

    /**
     * Copies all mappings of {@code source} into this map; on duplicate keys the source's value
     * wins. Does nothing when {@code source} is this map, {@code null}, or empty.
     *
     * @throws UnsupportedOperationException if this map is frozen (and {@code source} is non-empty)
     * @throws ConcurrentModificationException if called during {@code forEach}
     */
    @Override
    public void putAll(final ReadOnlyStringMap source) {
        if (source == this || source == null || source.isEmpty()) {
            return;
        }
        assertNotFrozen();
        assertNoConcurrentModification();
        if (source instanceof SortedArrayStringMap) {
            if (this.size == 0) {
                initFrom0((SortedArrayStringMap) source);
            } else {
                merge((SortedArrayStringMap) source);
            }
        } else {
            source.forEach(PUT_ALL, this);
        }
    }

    /**
     * Replaces this map's contents with a copy of {@code other}'s. Assumes this map is currently
     * empty; the current arrays are reused when they are large enough.
     */
    private void initFrom0(final SortedArrayStringMap other) {
        if (keys.length < other.size) {
            keys = new String[other.threshold];
            values = new Object[other.threshold];
        }
        System.arraycopy(other.keys, 0, keys, 0, other.size);
        System.arraycopy(other.values, 0, values, 0, other.size);
        size = other.size;
        // The threshold must describe the arrays we actually hold. Adopting other.threshold while
        // reusing smaller arrays would stop ensureCapacity() from growing them and let insertAt()
        // write past the end. Uninflated (EMPTY) arrays are sized from the threshold later.
        threshold = keys == EMPTY ? other.threshold : keys.length;
    }

    /**
     * Merges {@code other} into this non-empty map; on duplicate keys {@code other}'s value wins.
     * <p>
     * The larger map's entries are copied to the front as the sorted base, and the smaller map's
     * entries are appended after them. Each appended entry is then inserted into place.
     * {@code threshold} is first set large enough for all entries, so {@link #insertAt} never
     * resizes during the loop. A resize would lose the pending tail, because {@link #resize} copies
     * only {@code [0, size)}.
     * </p>
     */
    private void merge(final SortedArrayStringMap other) {
        final String[] myKeys = keys;
        final Object[] myVals = values;
        final int newSize = other.size + this.size;
        threshold = ceilingNextPowerOfTwo(newSize);
        if (keys.length < threshold) {
            keys = new String[threshold];
            values = new Object[threshold];
        }
        // When other is the base, its entries must not be overwritten by ours.
        boolean overwrite = true;
        if (other.size() > size()) {
            // Move our entries out of the way first: myKeys may be the same array as keys.
            System.arraycopy(myKeys, 0, keys, other.size, this.size);
            System.arraycopy(myVals, 0, values, other.size, this.size);
            System.arraycopy(other.keys, 0, keys, 0, other.size);
            System.arraycopy(other.values, 0, values, 0, other.size);
            size = other.size;
            overwrite = false;
        } else {
            System.arraycopy(myKeys, 0, keys, 0, this.size);
            System.arraycopy(myVals, 0, values, 0, this.size);
            System.arraycopy(other.keys, 0, keys, this.size, other.size);
            System.arraycopy(other.values, 0, values, this.size, other.size);
        }
        // size <= i throughout, so each insertAt only clobbers slots that were already consumed.
        for (int i = size; i < newSize; i++) {
            final int index = indexOfKey(keys[i]);
            if (index < 0) {
                insertAt(~index, keys[i], values[i]);
            } else if (overwrite) {
                keys[index] = keys[i];
                values[index] = values[i];
            }
        }
        // Clear leftover copies of duplicate keys beyond the merged size.
        Arrays.fill(keys, size, newSize, null);
        Arrays.fill(values, size, newSize, null);
    }

    /**
     * Grows the backing arrays when they are full. A zero threshold (e.g. from a deserialized
     * capacity of 0) grows to the default capacity, since doubling it would leave it at zero.
     */
    private void ensureCapacity() {
        if (size >= threshold) {
            resize(threshold == 0 ? DEFAULT_INITIAL_CAPACITY : threshold * 2);
        }
    }

    /** Reallocates both arrays to {@code newCapacity}, copying the first {@code size} entries. */
    private void resize(final int newCapacity) {
        final String[] oldKeys = keys;
        final Object[] oldValues = values;
        keys = new String[newCapacity];
        values = new Object[newCapacity];
        System.arraycopy(oldKeys, 0, keys, 0, size);
        System.arraycopy(oldValues, 0, values, 0, size);
        threshold = newCapacity;
    }

    /** Allocates fresh, empty arrays of length {@code toSize}. */
    private void inflateTable(final int toSize) {
        threshold = toSize;
        keys = new String[toSize];
        values = new Object[toSize];
    }

    /**
     * Removes the mapping for {@code key}, if present. The frozen and iterating checks apply only
     * when a mapping is actually removed.
     */
    @Override
    public void remove(final String key) {
        if (keys == EMPTY) {
            return;
        }
        final int index = indexOfKey(key);
        if (index >= 0) {
            assertNotFrozen();
            assertNoConcurrentModification();
            System.arraycopy(keys, index + 1, keys, index, size - 1 - index);
            System.arraycopy(values, index + 1, values, index, size - 1 - index);
            keys[size - 1] = null;
            values[size - 1] = null;
            size--;
        }
    }

    /** Returns the key at {@code index}, or {@code null} if {@code index} is out of range. */
    @Override
    public String getKeyAt(final int index) {
        if (index < 0 || index >= size) {
            return null;
        }
        return keys[index];
    }

    /** Returns the value at {@code index}, or {@code null} if {@code index} is out of range. */
    @SuppressWarnings("unchecked")
    @Override
    public <V> V getValueAt(final int index) {
        if (index < 0 || index >= size) {
            return null;
        }
        return (V) values[index];
    }

    @Override
    public int size() {
        return size;
    }

    /**
     * Passes each mapping, in key order, to {@code action}. Modifying the map from within
     * {@code action} throws {@link ConcurrentModificationException}.
     */
    @SuppressWarnings("unchecked")
    @Override
    public <V> void forEach(final BiConsumer<String, ? super V> action) {
        // Restore rather than clear the flag, so that a nested forEach does not re-enable
        // modification while an outer iteration is still running.
        final boolean wasIterating = iterating;
        iterating = true;
        try {
            for (int i = 0; i < size; i++) {
                action.accept(keys[i], (V) values[i]);
            }
        } finally {
            iterating = wasIterating;
        }
    }

    /**
     * Passes each mapping, in key order, together with {@code state} to {@code action}. Modifying
     * the map from within {@code action} throws {@link ConcurrentModificationException}.
     */
    @SuppressWarnings("unchecked")
    @Override
    public <V, T> void forEach(final TriConsumer<String, ? super V, T> action, final T state) {
        // Restore rather than clear the flag; see forEach(BiConsumer).
        final boolean wasIterating = iterating;
        iterating = true;
        try {
            for (int i = 0; i < size; i++) {
                action.accept(keys[i], (V) values[i], state);
            }
        } finally {
            iterating = wasIterating;
        }
    }

    /**
     * Returns true if {@code obj} is a {@code SortedArrayStringMap} with the same size and equal
     * keys and values at every index.
     */
    @Override
    public boolean equals(final Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof SortedArrayStringMap)) {
            return false;
        }
        final SortedArrayStringMap other = (SortedArrayStringMap) obj;
        if (this.size() != other.size()) {
            return false;
        }
        for (int i = 0; i < size(); i++) {
            if (!Objects.equals(keys[i], other.keys[i])) {
                return false;
            }
            if (!Objects.equals(values[i], other.values[i])) {
                return false;
            }
        }
        return true;
    }

    /** Combines the size and the hash codes of the stored keys and values, consistent with {@link #equals}. */
    @Override
    public int hashCode() {
        int result = 37;
        result = HASHVAL * result + size;
        result = HASHVAL * result + hashCode(keys, size);
        result = HASHVAL * result + hashCode(values, size);
        return result;
    }

    /** Order-sensitive hash of the first {@code length} elements of {@code array} (null hashes to 0). */
    private static int hashCode(final Object[] array, final int length) {
        int result = 1;
        for (int i = 0; i < length; i++) {
            result = HASHVAL * result + (array[i] == null ? 0 : array[i].hashCode());
        }
        return result;
    }

    /** Renders the mappings as {@code {k1=v1, k2=v2}}; a value that is this map prints as {@code (this map)}. */
    @Override
    public String toString() {
        final StringBuilder sb = new StringBuilder(256);
        sb.append('{');
        for (int i = 0; i < size; i++) {
            if (i > 0) {
                sb.append(", ");
            }
            sb.append(keys[i]).append('=');
            sb.append(values[i] == this ? "(this map)" : values[i]);
        }
        sb.append('}');
        return sb.toString();
    }

    /**
     * Writes the non-transient fields, then the capacity and size. Each key follows, together with
     * its value marshalled to a byte array ({@code null} for a {@code null} value). A value that
     * fails to marshal is logged and written as {@code null}.
     */
    private void writeObject(final java.io.ObjectOutputStream s) throws IOException {
        s.defaultWriteObject();
        if (keys == EMPTY) {
            s.writeInt(ceilingNextPowerOfTwo(threshold));
        } else {
            s.writeInt(keys.length);
        }
        s.writeInt(size);
        if (size > 0) {
            for (int i = 0; i < size; i++) {
                s.writeObject(keys[i]);
                try {
                    s.writeObject(marshall(values[i]));
                } catch (final Exception e) {
                    handleSerializationException(e, i, keys[i]);
                    s.writeObject(null);
                }
            }
        }
    }

    /** Serializes {@code obj} to a standalone byte array; returns {@code null} for a {@code null} input. */
    private static byte[] marshall(final Object obj) throws IOException {
        if (obj == null) {
            return null;
        }
        final ByteArrayOutputStream bout = new ByteArrayOutputStream();
        try (ObjectOutputStream oos = new ObjectOutputStream(bout)) {
            oos.writeObject(obj);
            oos.flush();
            return bout.toByteArray();
        }
    }

    /**
     * Deserializes one value from {@code data} using a stream that applies the same filtering as
     * {@code inputStream}. If {@code inputStream} is a {@code FilteredObjectInputStream}, its
     * allowed classes are reused. Otherwise its {@code ObjectInputFilter} is wrapped via
     * {@code DefaultObjectInputFilter.newInstance} and installed reflectively.
     *
     * @throws StreamCorruptedException if the filter cannot be read or installed reflectively
     */
    private static Object unmarshall(final byte[] data, final ObjectInputStream inputStream)
            throws IOException, ClassNotFoundException {
        final ByteArrayInputStream bin = new ByteArrayInputStream(data);
        Collection<String> allowedClasses = null;
        ObjectInputStream ois;
        if (inputStream instanceof FilteredObjectInputStream) {
            allowedClasses = ((FilteredObjectInputStream) inputStream).getAllowedClasses();
            ois = new FilteredObjectInputStream(bin, allowedClasses);
        } else {
            try {
                final Object obj = getObjectInputFilter.invoke(inputStream);
                final Object filter = newObjectInputFilter.invoke(null, obj);
                ois = new ObjectInputStream(bin);
                setObjectInputFilter.invoke(ois, filter);
            } catch (final IllegalAccessException | InvocationTargetException ex) {
                // StreamCorruptedException has no (String, Throwable) constructor; attach the cause explicitly.
                final StreamCorruptedException sce = new StreamCorruptedException("Unable to set ObjectInputFilter on stream");
                sce.initCause(ex);
                throw sce;
            }
        }
        try {
            return ois.readObject();
        } finally {
            ois.close();
        }
    }

    /**
     * Returns the smallest power of two that is {@code >= x}, for {@code 1 <= x <= 2^30}.
     * Returns 1 for {@code x == 0}, because the shift distance 32 wraps to 0.
     */
    private static int ceilingNextPowerOfTwo(final int x) {
        final int BITS_PER_INT = 32;
        return 1 << (BITS_PER_INT - Integer.numberOfLeadingZeros(x - 1));
    }

    /**
     * Reads back the form written by {@link #writeObject}. Values are unmarshalled through a
     * filtered stream. A value that fails to deserialize is logged and replaced by {@code null}.
     *
     * @throws IllegalArgumentException if {@code s} is not a {@code FilteredObjectInputStream}
     *         and this JVM offers no {@code setObjectInputFilter}
     * @throws InvalidObjectException if the capacity or mapping count is negative, or the mapping
     *         count exceeds the capacity
     */
    private void readObject(final java.io.ObjectInputStream s) throws IOException, ClassNotFoundException {
        if (!(s instanceof FilteredObjectInputStream) && setObjectInputFilter == null) {
            throw new IllegalArgumentException("readObject requires a FilteredObjectInputStream or an ObjectInputStream that accepts an ObjectInputFilter");
        }
        s.defaultReadObject();
        keys = EMPTY;
        values = EMPTY;
        final int capacity = s.readInt();
        if (capacity < 0) {
            throw new InvalidObjectException("Illegal capacity: " + capacity);
        }
        final int mappings = s.readInt();
        if (mappings < 0) {
            throw new InvalidObjectException("Illegal mappings count: " + mappings);
        }
        // writeObject() always records a capacity >= size, so a larger count means a corrupt or
        // crafted stream. Reject it here instead of failing with ArrayIndexOutOfBoundsException.
        if (mappings > capacity) {
            throw new InvalidObjectException("Mappings count " + mappings + " exceeds capacity " + capacity);
        }
        if (mappings > 0) {
            inflateTable(capacity);
        } else {
            threshold = capacity;
        }
        for (int i = 0; i < mappings; i++) {
            keys[i] = (String) s.readObject();
            try {
                final byte[] marshalledObject = (byte[]) s.readObject();
                values[i] = marshalledObject == null ? null : unmarshall(marshalledObject, s);
            } catch (final Exception | LinkageError error) {
                handleSerializationException(error, i, keys[i]);
                values[i] = null;
            }
        }
        size = mappings;
    }

    /** Logs a value that could not be (de)serialized, identifying it by index and key. */
    private void handleSerializationException(final Throwable t, final int i, final String key) {
        StatusLogger.getLogger().warn("Ignoring {} for key[{}] ('{}')", String.valueOf(t), i, key);
    }
}