package org.jruby.runtime.opto;
import java.lang.invoke.SwitchPoint;
import java.util.List;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;
import java.util.concurrent.atomic.AtomicStampedReference;
/**
 * An {@link Invalidator} backed by a {@link SwitchPoint} that "fails over" to a permanently
 * invalid switch point once it has been invalidated too many times.
 *
 * <p>The current switch point and the number of invalidations performed so far are kept together
 * in one {@link AtomicStampedReference}: the reference is the switch point and the stamp is the
 * failure count. This lets both be read and updated atomically.
 *
 * <p>{@code DUMMY} is a shared, already-invalidated sentinel. It has two meanings:
 * <ul>
 *   <li>"No switch point has been handed out yet", while the failure count is at most
 *       {@code maxFailures}.</li>
 *   <li>"Failed over", once the count has exceeded {@code maxFailures}. In this state
 *       {@link #getData()} keeps returning {@code DUMMY}, which is never valid.</li>
 * </ul>
 */
public class FailoverSwitchPointInvalidator implements Invalidator {
    /** Pre-invalidated sentinel switch point shared by all instances. */
    private static final SwitchPoint DUMMY = new SwitchPoint();
    static {
        SwitchPoint.invalidateAll(new SwitchPoint[]{DUMMY});
    }

    /** Current switch point (reference) paired with the number of invalidations so far (stamp). */
    private final AtomicStampedReference<SwitchPoint> switchPoint = new AtomicStampedReference<>(DUMMY, 0);
    private final int maxFailures;

    /**
     * @param maxFailures number of invalidations after which this invalidator stops installing
     *     fresh switch points and falls back to the permanently invalid {@code DUMMY}
     */
    public FailoverSwitchPointInvalidator(int maxFailures) {
        this.maxFailures = maxFailures;
    }

    /**
     * Invalidates the current switch point, if one has been handed out.
     *
     * <p>The current switch point is replaced with a fresh one, or with {@code DUMMY} once the
     * failure budget is used up.
     */
    public void invalidate() {
        SwitchPoint old = swapOut();
        // DUMMY means nothing was installed, so there is nothing to invalidate.
        if (old == DUMMY) return;
        SwitchPoint.invalidateAll(new SwitchPoint[]{old});
    }

    /**
     * Invalidates the switch points of all given invalidators in a single batch.
     *
     * @param invalidators invalidators to invalidate; every element is expected to be a
     *     {@code FailoverSwitchPointInvalidator} (checked only by {@code assert})
     */
    public void invalidateAll(List<Invalidator> invalidators) {
        SwitchPoint[] switchPoints = new SwitchPoint[invalidators.size()];
        for (int i = 0; i < switchPoints.length; i++) {
            Invalidator invalidator = invalidators.get(i);
            assert invalidator instanceof FailoverSwitchPointInvalidator;
            // Each invalidator swaps in its replacement and hands back the OLD switch point,
            // which is the one existing guards are bound to and must be invalidated here.
            switchPoints[i] = ((FailoverSwitchPointInvalidator) invalidator).replaceSwitchPoint();
        }
        SwitchPoint.invalidateAll(switchPoints);
    }

    /**
     * Returns the switch point that guards should currently bind to.
     *
     * <p>If no switch point is installed and the failure budget is not exhausted, a fresh one is
     * created and installed. Otherwise the installed one is returned. That is {@code DUMMY}
     * (already invalid) once this invalidator has failed over.
     *
     * @return the current {@link SwitchPoint}
     */
    public synchronized Object getData() {
        int[] stampHolder = new int[1];
        while (true) {
            // Atomic snapshot of (switch point, failure count).
            SwitchPoint current = this.switchPoint.get(stampHolder);
            int failures = stampHolder[0];
            if (current != DUMMY || failures > maxFailures) {
                return current;
            }
            SwitchPoint fresh = new SwitchPoint();
            // invalidate() is not synchronized, so the state may change under us; retry on CAS failure.
            if (this.switchPoint.compareAndSet(DUMMY, fresh, failures, failures)) {
                return fresh;
            }
        }
    }

    /**
     * Swaps out the current switch point and records one more failure, without invalidating it.
     *
     * @return the switch point that was replaced, which the caller must invalidate. Returns
     *     {@code DUMMY} if none was installed; invalidating it again is harmless.
     */
    public synchronized SwitchPoint replaceSwitchPoint() {
        return swapOut();
    }

    /**
     * Atomically replaces the installed switch point and increments the failure count.
     *
     * <p>The replacement is a fresh switch point while the failure count is below
     * {@code maxFailures}, and {@code DUMMY} otherwise. Installing {@code DUMMY} is what makes
     * this invalidator fail over. Once {@code DUMMY} is installed the count stops growing.
     *
     * @return the replaced switch point, or {@code DUMMY} if none was installed
     */
    private SwitchPoint swapOut() {
        int[] stampHolder = new int[1];
        while (true) {
            SwitchPoint current = switchPoint.get(stampHolder);
            if (current == DUMMY) return DUMMY;
            int failures = stampHolder[0];
            SwitchPoint next = failures < maxFailures ? new SwitchPoint() : DUMMY;
            // Retry on contention so every swap is counted and the switch point actually
            // replaced is the one returned for invalidation.
            if (switchPoint.compareAndSet(current, next, failures, failures + 1)) {
                return current;
            }
        }
    }
}