package com.mongodb.internal.connection;
import com.mongodb.MongoBulkWriteException;
import com.mongodb.ServerAddress;
import com.mongodb.WriteConcern;
import com.mongodb.bulk.BulkWriteError;
import com.mongodb.bulk.BulkWriteInsert;
import com.mongodb.bulk.BulkWriteResult;
import com.mongodb.bulk.BulkWriteUpsert;
import com.mongodb.bulk.WriteConcernError;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.TreeSet;
import static com.mongodb.assertions.Assertions.notNull;
import static java.util.Collections.singletonList;
import static java.util.Comparator.comparingInt;
/**
 * Combines the outcomes of the individual batches of a split bulk write into a single
 * {@link BulkWriteResult} or {@link MongoBulkWriteException}.
 *
 * <p>Inserted/matched/deleted/modified counts are summed across batches. Upserts, inserts and write
 * errors are kept in sets ordered by their index; because those sets compare elements by index only,
 * at most one entry per index is retained (the first one added wins). Write errors are translated
 * through the supplied {@link IndexMap} when merged (presumably batch-local position to position in the
 * original request — NOTE(review): confirm against IndexMap); upserts and inserts are stored with the
 * indices they already carry.</p>
 *
 * <p>This class performs no synchronization of its mutable state.</p>
 */
public class BulkWriteBatchCombiner {
    private final ServerAddress serverAddress;
    private final boolean ordered;
    private final WriteConcern writeConcern;

    private int insertedCount;
    private int matchedCount;
    private int deletedCount;
    private int modifiedCount;

    // TreeSet treats comparator-equal elements as duplicates, so a later entry with an index that is
    // already present is silently ignored.
    private final Set<BulkWriteUpsert> writeUpserts = new TreeSet<>(comparingInt(BulkWriteUpsert::getIndex));
    private final Set<BulkWriteInsert> writeInserts = new TreeSet<>(comparingInt(BulkWriteInsert::getIndex));
    private final Set<BulkWriteError> writeErrors = new TreeSet<>(comparingInt(BulkWriteError::getIndex));
    private final Set<String> errorLabels = new HashSet<>();
    // History of recorded write concern errors; only the most recent one is reported in getError().
    private final List<WriteConcernError> writeConcernErrors = new ArrayList<>();

    /**
     * Creates a combiner for one bulk write.
     *
     * @param serverAddress the server address reported in any resulting exception; must not be null
     * @param ordered       whether the bulk write is ordered; an ordered write stops after the first write error
     * @param writeConcern  the write concern; decides between an acknowledged and unacknowledged result; must not be null
     */
    public BulkWriteBatchCombiner(final ServerAddress serverAddress, final boolean ordered, final WriteConcern writeConcern) {
        // Validation order is kept: writeConcern is checked before serverAddress.
        this.writeConcern = notNull("writeConcern", writeConcern);
        this.ordered = ordered;
        this.serverAddress = notNull("serverAddress", serverAddress);
    }

    /**
     * Adds the counts, upserts and inserts of one batch.
     *
     * <p>Upsert and insert indices are stored as-is; no {@link IndexMap} translation is applied here.</p>
     *
     * @param result the batch result
     */
    public void addResult(final BulkWriteResult result) {
        insertedCount += result.getInsertedCount();
        matchedCount += result.getMatchedCount();
        deletedCount += result.getDeletedCount();
        modifiedCount += result.getModifiedCount();
        writeUpserts.addAll(result.getUpserts());
        writeInserts.addAll(result.getInserts());
    }

    /**
     * Merges a failed batch: its partial write result, its error labels, its write errors (translated
     * through {@code indexMap}) and its write concern error, if any.
     *
     * @param exception the exception produced for the batch
     * @param indexMap  the mapping applied to the indices of the exception's write errors
     */
    public void addErrorResult(final MongoBulkWriteException exception, final IndexMap indexMap) {
        addResult(exception.getWriteResult());
        errorLabels.addAll(exception.getErrorLabels());
        mergeWriteErrors(exception.getWriteErrors(), indexMap);
        mergeWriteConcernError(exception.getWriteConcernError());
    }

    /**
     * Records a single write error, translating its index through {@code indexMap}.
     *
     * @param writeError the write error; must not be null
     * @param indexMap   the mapping applied to the error's index
     */
    public void addWriteErrorResult(final BulkWriteError writeError, final IndexMap indexMap) {
        notNull("writeError", writeError);
        mergeWriteErrors(singletonList(writeError), indexMap);
    }

    /**
     * Records a write concern error and its error labels.
     *
     * @param writeConcernError the write concern error; must not be null
     */
    public void addWriteConcernErrorResult(final WriteConcernError writeConcernError) {
        notNull("writeConcernError", writeConcernError);
        mergeWriteConcernError(writeConcernError);
    }

    /**
     * Records a batch's write errors (translated through {@code indexMap}) and an optional write concern error.
     *
     * @param writeErrors       the write errors to record
     * @param writeConcernError the write concern error, or {@code null} if there is none
     * @param indexMap          the mapping applied to the write error indices
     */
    public void addErrorResult(final List<BulkWriteError> writeErrors,
                               final WriteConcernError writeConcernError, final IndexMap indexMap) {
        mergeWriteErrors(writeErrors, indexMap);
        mergeWriteConcernError(writeConcernError);
    }

    /**
     * Returns the combined result of all batches.
     *
     * @return the combined result; unacknowledged if the write concern is unacknowledged
     * @throws MongoBulkWriteException if any write error or write concern error has been recorded
     */
    public BulkWriteResult getResult() {
        throwOnError();
        return createResult();
    }

    /**
     * Tells whether the remaining batches should be skipped.
     *
     * <p>Only write errors in an ordered bulk write stop execution; write concern errors never do.</p>
     *
     * @return {@code true} if the bulk write is ordered and at least one write error was recorded
     */
    public boolean shouldStopSendingMoreBatches() {
        return ordered && hasWriteErrors();
    }

    /**
     * @return {@code true} if any write error or write concern error has been recorded
     */
    public boolean hasErrors() {
        return hasWriteErrors() || hasWriteConcernErrors();
    }

    /**
     * Builds the exception describing all recorded errors.
     *
     * @return the exception carrying the combined result, all write errors, the most recently recorded
     *         write concern error and the accumulated error labels; {@code null} if no error was recorded
     */
    public MongoBulkWriteException getError() {
        if (!hasErrors()) {
            return null;
        }
        return new MongoBulkWriteException(createResult(), new ArrayList<>(writeErrors),
                lastWriteConcernError(), serverAddress, errorLabels);
    }

    /**
     * Appends a write concern error and its labels unless it equals the most recently recorded one.
     * Only consecutive duplicates are suppressed; {@code null} is ignored.
     */
    private void mergeWriteConcernError(final WriteConcernError writeConcernError) {
        if (writeConcernError == null) {
            return;
        }
        WriteConcernError previous = lastWriteConcernError();
        if (previous == null || !writeConcernError.equals(previous)) {
            writeConcernErrors.add(writeConcernError);
            errorLabels.addAll(writeConcernError.getErrorLabels());
        }
    }

    /** Returns the most recently recorded write concern error, or {@code null} if none was recorded. */
    private WriteConcernError lastWriteConcernError() {
        return writeConcernErrors.isEmpty() ? null : writeConcernErrors.get(writeConcernErrors.size() - 1);
    }

    /** Copies each write error with its index translated through {@code indexMap}. */
    private void mergeWriteErrors(final List<BulkWriteError> newWriteErrors, final IndexMap indexMap) {
        for (BulkWriteError cur : newWriteErrors) {
            writeErrors.add(new BulkWriteError(cur.getCode(), cur.getMessage(), cur.getDetails(), indexMap.map(cur.getIndex())));
        }
    }

    /** Throws the combined {@link MongoBulkWriteException} if any error was recorded. */
    private void throwOnError() {
        if (hasErrors()) {
            throw getError();
        }
    }

    /** Builds the combined result from the accumulated counts, upserts and inserts. */
    private BulkWriteResult createResult() {
        return writeConcern.isAcknowledged()
                ? BulkWriteResult.acknowledged(insertedCount, matchedCount, deletedCount, modifiedCount,
                        new ArrayList<>(writeUpserts), new ArrayList<>(writeInserts))
                : BulkWriteResult.unacknowledged();
    }

    private boolean hasWriteErrors() {
        return !writeErrors.isEmpty();
    }

    private boolean hasWriteConcernErrors() {
        return !writeConcernErrors.isEmpty();
    }
}