package org.hibernate.engine.jdbc.batch.internal;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.Map;
import org.hibernate.HibernateException;
import org.hibernate.engine.jdbc.batch.spi.BatchKey;
import org.hibernate.engine.jdbc.spi.JdbcCoordinator;
import org.hibernate.internal.CoreMessageLogger;
import org.jboss.logging.Logger;
/**
 * A batch that groups calls to {@link PreparedStatement#addBatch()} and runs them through
 * {@link PreparedStatement#executeBatch()}. Execution happens implicitly from {@link #addToBatch()}
 * once {@code batchSize} complete batch entries have accumulated, or explicitly through
 * {@link #doExecuteBatch()}.
 * <p/>
 * A "batch entry" is complete once {@code getKey().getBatchedStatementCount()} statements have
 * been added to the batch.
 */
public class BatchingBatch extends AbstractBatchImpl {
	private static final CoreMessageLogger LOG = Logger.getMessageLogger(
			CoreMessageLogger.class,
			BatchingBatch.class.getName()
	);

	// Number of complete batch entries after which addToBatch() triggers an implicit execution.
	private final int batchSize;
	// Number of complete batch entries accumulated since the last execution.
	private int batchPosition;
	// Set once an implicit execution has happened. It is never reset and is only used to suppress
	// the "No batched statements to execute" debug message.
	private boolean batchExecuted;
	// Statements added for the current batch entry. The entry is complete when this reaches
	// getKey().getBatchedStatementCount().
	private int statementPosition;

	// SQL and statement most recently handed out by getBatchStatement(); addToBatch() operates on these.
	private String currentStatementSql;
	private PreparedStatement currentStatement;

	/**
	 * Creates a batch for the statements identified by {@code key}.
	 *
	 * @param key the batch key; its expectation must report that it can be batched
	 * @param jdbcCoordinator the coordinator, passed through to the superclass
	 * @param batchSize number of complete batch entries after which {@link #addToBatch()} executes
	 * the batch implicitly
	 *
	 * @throws HibernateException if {@code key.getExpectation().canBeBatched()} is {@code false}
	 */
	public BatchingBatch(
			BatchKey key,
			JdbcCoordinator jdbcCoordinator,
			int batchSize) {
		super( key, jdbcCoordinator );
		if ( ! key.getExpectation().canBeBatched() ) {
			throw new HibernateException( "attempting to batch an operation which cannot be batched" );
		}
		this.batchSize = batchSize;
	}

	/**
	 * Obtains the statement for {@code sql} from the superclass and remembers it, together with its
	 * SQL, as the target of subsequent {@link #addToBatch()} calls.
	 */
	@Override
	public PreparedStatement getBatchStatement(String sql, boolean callable) {
		currentStatementSql = sql;
		currentStatement = super.getBatchStatement( sql, callable );
		return currentStatement;
	}

	/**
	 * Calls {@link PreparedStatement#addBatch()} on the statement most recently returned by
	 * {@link #getBatchStatement(String, boolean)}.
	 * <p/>
	 * When this completes a batch entry and {@code batchSize} entries have accumulated, observers are
	 * notified of an implicit execution and the batch is executed.
	 * <p/>
	 * If {@code addBatch()} fails, the {@link SQLException} is converted through
	 * {@code sqlExceptionHelper()} and the result is thrown.
	 */
	@Override
	public void addToBatch() {
		try {
			currentStatement.addBatch();
		}
		catch ( SQLException e ) {
			// debug(Object, Throwable) records the stack trace. The former debugf() call treated 'e'
			// as an unused format argument and silently dropped it.
			LOG.debug( "SQLException escaped proxy", e );
			throw sqlExceptionHelper().convert( e, "could not perform addBatch", currentStatementSql );
		}
		statementPosition++;
		if ( statementPosition >= getKey().getBatchedStatementCount() ) {
			// The current entry is complete. Reset before any implicit execution so the counter is
			// consistent even when that execution throws.
			statementPosition = 0;
			batchPosition++;
			if ( batchPosition == batchSize ) {
				notifyObserversImplicitExecution();
				performExecution();
				batchPosition = 0;
				batchExecuted = true;
			}
		}
	}

	/**
	 * Executes any accumulated batch entries. If none are pending, nothing is executed; a debug
	 * message is logged unless an implicit execution has happened earlier.
	 */
	@Override
	protected void doExecuteBatch() {
		if ( batchPosition == 0 ) {
			if ( ! batchExecuted ) {
				LOG.debug( "No batched statements to execute" );
			}
		}
		else {
			performExecution();
		}
	}

	/**
	 * Runs {@link PreparedStatement#executeBatch()} on every statement held by this batch and verifies
	 * the returned row counts.
	 * <p/>
	 * {@code batchPosition} is reset to 0 whether or not execution succeeds. On any failure the batch
	 * is aborted before the exception propagates. A {@link SQLException} is converted through
	 * {@code sqlExceptionHelper()}.
	 */
	private void performExecution() {
		LOG.debugf( "Executing batch size: %s", batchPosition );
		try {
			for ( Map.Entry<String,PreparedStatement> entry : getStatements().entrySet() ) {
				try {
					final PreparedStatement statement = entry.getValue();
					final int[] rowCounts;
					try {
						transactionContext().startBatchExecution();
						rowCounts = statement.executeBatch();
					}
					finally {
						transactionContext().endBatchExecution();
					}
					checkRowCounts( rowCounts, statement );
				}
				catch ( SQLException e ) {
					abortBatch();
					throw sqlExceptionHelper().convert( e, "could not execute batch", entry.getKey() );
				}
				catch ( RuntimeException e ) {
					// Failures other than SQLException (e.g. a row-count expectation rejected in
					// checkRowCounts) previously propagated without aborting the batch. Abort them too,
					// the same way as the SQLException path.
					abortBatch();
					throw e;
				}
			}
		}
		catch ( RuntimeException re ) {
			LOG.unableToExecuteBatch( re.getMessage() );
			throw re;
		}
		finally {
			batchPosition = 0;
		}
	}

	/**
	 * Verifies each row count against the batch key's expectation. Logs a message if the number of
	 * row counts differs from the number of accumulated batch entries.
	 *
	 * @param rowCounts the row counts returned by {@link PreparedStatement#executeBatch()}
	 * @param ps the statement that produced {@code rowCounts}
	 */
	private void checkRowCounts(int[] rowCounts, PreparedStatement ps) throws SQLException, HibernateException {
		final int numberOfRowCounts = rowCounts.length;
		if ( numberOfRowCounts != batchPosition ) {
			LOG.unexpectedRowCounts();
		}
		for ( int i = 0; i < numberOfRowCounts; i++ ) {
			getKey().getExpectation().verifyOutcome( rowCounts[i], ps, i );
		}
	}
}