package org.apache.lucene.queries.intervals;
import java.util.Objects;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.similarities.Similarity;
/**
 * Converts the sloppy frequency of an interval into a score contribution.
 *
 * <p>Instances are created through {@link #saturationFunction(float)} and {@link
 * #sigmoidFunction(float, float)}, which reject non-finite or non-positive parameters. Every
 * implementation must provide {@code equals}, {@code hashCode} and {@code toString}; the built-in
 * implementations compare their parameters by value.
 */
abstract class IntervalScoreFunction {

  /**
   * Creates a saturation function {@code w * S / (S + k)}, where {@code S} is the sloppy frequency,
   * {@code w} is the weight passed to {@link #scorer(float)} and {@code k} is the given pivot.
   *
   * @param pivot {@code k}, the frequency at which the score contribution equals {@code w / 2};
   *     must be finite and {@code > 0}
   * @return a new saturation function
   * @throws IllegalArgumentException if {@code pivot} is not finite (including NaN) or is {@code
   *     <= 0}
   */
  static IntervalScoreFunction saturationFunction(float pivot) {
    checkPositiveFinite("pivot", pivot);
    return new SaturationFunction(pivot);
  }

  /**
   * Creates a sigmoid function {@code w * S^a / (S^a + k^a)}, where {@code S} is the sloppy
   * frequency, {@code w} is the weight passed to {@link #scorer(float)}, {@code k} is the given
   * pivot and {@code a} is the given exponent.
   *
   * @param pivot {@code k}, the frequency at which the score contribution equals {@code w / 2};
   *     must be finite and {@code > 0}
   * @param exp {@code a}, the exponent; must be finite and {@code > 0}
   * @return a new sigmoid function
   * @throws IllegalArgumentException if either argument is not finite (including NaN) or is {@code
   *     <= 0}; {@code pivot} is checked first
   */
  static IntervalScoreFunction sigmoidFunction(float pivot, float exp) {
    checkPositiveFinite("pivot", pivot);
    checkPositiveFinite("exp", exp);
    return new SigmoidFunction(pivot, exp);
  }

  /**
   * Throws an {@link IllegalArgumentException} mentioning {@code name} unless {@code value} is
   * finite and strictly positive.
   */
  private static void checkPositiveFinite(String name, float value) {
    // NaN fails both comparisons against 0, so it is caught by the isFinite check.
    if (value <= 0 || Float.isFinite(value) == false) {
      throw new IllegalArgumentException(name + " must be > 0, got: " + value);
    }
  }

  /**
   * Returns a scorer that applies this function to a frequency and scales the result by {@code
   * weight}.
   *
   * @param weight {@code w}, the value the score contribution approaches as the frequency grows
   * @return a scorer for this function
   */
  public abstract Similarity.SimScorer scorer(float weight);

  /**
   * Explains the score this function produces for {@code sloppyFreq} with the given weight.
   *
   * @param interval text identifying the interval, appended to the description of {@code S}
   * @param weight {@code w}, as passed to {@link #scorer(float)}
   * @param sloppyFreq {@code S}, the sloppy frequency being scored
   * @return an explanation whose value is the computed score
   */
  public abstract Explanation explain(String interval, float weight, float sloppyFreq);

  // equals, hashCode and toString are redeclared abstract so that every implementation must
  // override them rather than inherit Object's identity-based versions.
  @Override
  public abstract boolean equals(Object other);

  @Override
  public abstract int hashCode();

  @Override
  public abstract String toString();

  /** Scores a frequency {@code S} as {@code w * S / (S + k)}. */
  private static final class SaturationFunction extends IntervalScoreFunction {
    /** {@code k}: the frequency at which the score contribution is half the weight. */
    private final float pivot;

    private SaturationFunction(float pivot) {
      this.pivot = pivot;
    }

    @Override
    public Similarity.SimScorer scorer(float weight) {
      return new Similarity.SimScorer() {
        @Override
        public float score(float freq, long norm) {
          // Written as 1 - k / (S + k) rather than S / (S + k). Every operation in this form is
          // monotone in S, so, for a non-negative weight, rounding cannot make the score decrease
          // as the frequency grows. The norm is not used.
          return weight * (1.0f - pivot / (pivot + freq));
        }
      };
    }

    @Override
    public Explanation explain(String interval, float weight, float sloppyFreq) {
      // The scorer ignores the norm, so any value can be passed here.
      float score = scorer(weight).score(sloppyFreq, 1L);
      return Explanation.match(
          score,
          "Saturation function on interval frequency, computed as w * S / (S + k) from:",
          Explanation.match(weight, "w, weight of this function"),
          Explanation.match(
              pivot, "k, pivot feature value that would give a score contribution equal to w/2"),
          Explanation.match(sloppyFreq, "S, the sloppy frequency of the interval " + interval));
    }

    @Override
    public boolean equals(Object o) {
      if (this == o) return true;
      if (o == null || getClass() != o.getClass()) return false;
      SaturationFunction that = (SaturationFunction) o;
      return Float.compare(that.pivot, pivot) == 0;
    }

    @Override
    public int hashCode() {
      return Objects.hash(pivot);
    }

    @Override
    public String toString() {
      return "SaturationFunction(pivot=" + pivot + ")";
    }
  }

  /** Scores a frequency {@code S} as {@code w * S^a / (S^a + k^a)}. */
  private static final class SigmoidFunction extends IntervalScoreFunction {
    /** {@code k}: the frequency at which the score contribution is half the weight. */
    private final float pivot;
    /** {@code a}: the exponent controlling the steepness of the curve around {@code k}. */
    private final float a;

    private SigmoidFunction(float pivot, float a) {
      this.pivot = pivot;
      this.a = a;
    }

    @Override
    public Similarity.SimScorer scorer(float weight) {
      return new Similarity.SimScorer() {
        @Override
        public float score(float freq, long norm) {
          // S^a / (S^a + k^a) == 1 - 1 / ((S/k)^a + 1).
          //
          // Raising the ratio S/k, instead of S and k separately, avoids the overflow case where
          // k^a alone is +Infinity (e.g. k = 1e30, a = 11). That case would make
          // k^a / (S^a + k^a) evaluate to Infinity/Infinity = NaN for every frequency.
          //
          // The ratio of two floats is computed in double, so the division itself cannot overflow
          // or underflow. If the power overflows or underflows, the score saturates cleanly to w
          // or 0.
          //
          // Keeping the 1 - x form makes every step monotone in S, so rounding cannot make the
          // score decrease as the frequency grows. The norm is not used.
          double ratioPowA = Math.pow(freq / (double) pivot, a);
          return (float) (weight * (1.0 - 1.0 / (ratioPowA + 1.0)));
        }
      };
    }

    @Override
    public Explanation explain(String interval, float weight, float sloppyFreq) {
      // The scorer ignores the norm, so any value can be passed here.
      float score = scorer(weight).score(sloppyFreq, 1L);
      return Explanation.match(
          score,
          "Sigmoid function on interval frequency, computed as w * S^a / (S^a + k^a) from:",
          Explanation.match(weight, "w, weight of this function"),
          Explanation.match(
              pivot, "k, pivot feature value that would give a score contribution equal to w/2"),
          Explanation.match(
              a,
              "a, exponent, higher values make the function grow slower before k and faster after k"),
          Explanation.match(sloppyFreq, "S, the sloppy frequency of the interval " + interval));
    }

    @Override
    public boolean equals(Object o) {
      if (this == o) return true;
      if (o == null || getClass() != o.getClass()) return false;
      SigmoidFunction that = (SigmoidFunction) o;
      return Float.compare(that.pivot, pivot) == 0 && Float.compare(that.a, a) == 0;
    }

    @Override
    public int hashCode() {
      return Objects.hash(pivot, a);
    }

    @Override
    public String toString() {
      return "SigmoidFunction(pivot=" + pivot + ", a=" + a + ")";
    }
  }
}