package com.mongodb.internal.connection;
import com.mongodb.MongoCommandException;
import com.mongodb.MongoException;
import com.mongodb.MongoNodeIsRecoveringException;
import com.mongodb.MongoNotPrimaryException;
import com.mongodb.MongoSecurityException;
import com.mongodb.MongoSocketException;
import com.mongodb.MongoSocketReadTimeoutException;
import com.mongodb.connection.ClusterConnectionMode;
import com.mongodb.connection.ServerDescription;
import com.mongodb.connection.ServerId;
import com.mongodb.connection.TopologyVersion;
import com.mongodb.diagnostics.logging.Logger;
import com.mongodb.diagnostics.logging.Loggers;
import com.mongodb.event.CommandListener;
import com.mongodb.event.ServerClosedEvent;
import com.mongodb.event.ServerDescriptionChangedEvent;
import com.mongodb.event.ServerListener;
import com.mongodb.event.ServerOpeningEvent;
import com.mongodb.internal.async.SingleResultCallback;
import com.mongodb.internal.session.SessionContext;
import java.util.List;
import static com.mongodb.assertions.Assertions.isTrue;
import static com.mongodb.assertions.Assertions.notNull;
import static com.mongodb.connection.ServerConnectionState.CONNECTING;
import static com.mongodb.internal.async.ErrorHandlingResultCallback.errorHandlingCallback;
import static com.mongodb.internal.connection.ClusterableServer.ConnectionState.AFTER_HANDSHAKE;
import static com.mongodb.internal.connection.ClusterableServer.ConnectionState.BEFORE_HANDSHAKE;
import static com.mongodb.internal.operation.ServerVersionHelper.FOUR_DOT_TWO_WIRE_VERSION;
import static java.util.Arrays.asList;
/**
 * A {@link ClusterableServer} backed by a {@link ConnectionPool} and a {@link ServerMonitor}.
 *
 * <p>The server keeps the most recently accepted {@link ServerDescription} in a volatile field. Descriptions are
 * only ever replaced through the internal state listener, which also publishes a
 * {@link ServerDescriptionChangedEvent} to the configured {@link ServerListener}.</p>
 *
 * <p>Errors raised while checking out a connection or while executing a protocol on a connection are routed to
 * {@link #invalidate(ConnectionState, Throwable, int, int)}, which decides whether the description must be reset to
 * {@code CONNECTING} and whether the connection pool must be invalidated.</p>
 */
class DefaultServer implements ClusterableServer {
    private static final Logger LOGGER = Loggers.getLogger("connection");

    // Command error codes that force a pool invalidation on a not-primary / node-is-recovering error even when the
    // server's wire version would otherwise let the pool survive.
    // NOTE(review): presumably the server "shutdown in progress" codes - confirm against the server error code list.
    private static final List<Integer> SHUTDOWN_CODES = asList(91, 11600);

    private final ServerId serverId;
    private final ConnectionPool connectionPool;
    private final ClusterConnectionMode clusterConnectionMode;
    private final ConnectionFactory connectionFactory;
    private final ServerMonitor serverMonitor;
    private final ChangeListener<ServerDescription> serverStateListener;
    private final ServerListener serverListener;
    private final CommandListener commandListener;
    private final ClusterClock clusterClock;
    private volatile ServerDescription description;
    private volatile boolean isClosed;

    /**
     * Creates the server, publishes a {@link ServerOpeningEvent}, sets the initial {@code CONNECTING} description and
     * then creates and starts the server monitor.
     *
     * @param serverId              the server id, not null
     * @param clusterConnectionMode the cluster connection mode, not null
     * @param connectionPool        the pool that connections are checked out from, not null
     * @param connectionFactory     wraps pooled internal connections into public connection objects, not null
     * @param serverMonitorFactory  creates the monitor that is started by this constructor, not null
     * @param serverListener        receives opening, closed and description-changed events, not null
     * @param commandListener       passed to legacy protocols before execution; not null-checked here
     * @param clusterClock          the cluster clock wrapped into the session context of command protocols, not null
     */
    DefaultServer(final ServerId serverId, final ClusterConnectionMode clusterConnectionMode, final ConnectionPool connectionPool,
                  final ConnectionFactory connectionFactory, final ServerMonitorFactory serverMonitorFactory,
                  final ServerListener serverListener, final CommandListener commandListener, final ClusterClock clusterClock) {
        this.serverListener = notNull("serverListener", serverListener);
        this.commandListener = commandListener;
        this.clusterClock = notNull("clusterClock", clusterClock);
        // The assertion name now matches the argument actually being checked.
        notNull("serverId", serverId);
        notNull("serverMonitorFactory", serverMonitorFactory);
        this.clusterConnectionMode = notNull("clusterConnectionMode", clusterConnectionMode);
        this.connectionFactory = notNull("connectionFactory", connectionFactory);
        this.connectionPool = notNull("connectionPool", connectionPool);
        this.serverStateListener = new DefaultServerStateListener();
        this.serverId = serverId;

        serverListener.serverOpening(new ServerOpeningEvent(this.serverId));

        // The description must be in place before the monitor starts, because the monitor reports through
        // serverStateListener, which reads the current description.
        description = connectingDescription();
        serverMonitor = serverMonitorFactory.create(serverStateListener);
        serverMonitor.start();
    }

    /**
     * Checks a connection out of the pool and wraps it.
     *
     * <p>A {@link MongoSecurityException} invalidates the pool; a {@link MongoSocketException} is reported as a
     * before-handshake failure. In both cases the exception is rethrown.</p>
     *
     * @return the wrapped connection
     */
    @Override
    public Connection getConnection() {
        isTrue("open", !isClosed());
        try {
            return connectionFactory.create(connectionPool.get(), new DefaultServerProtocolExecutor(), clusterConnectionMode);
        } catch (MongoSecurityException e) {
            connectionPool.invalidate();
            throw e;
        } catch (MongoSocketException e) {
            invalidate(BEFORE_HANDSHAKE, e, connectionPool.getGeneration(), description.getMaxWireVersion());
            throw e;
        }
    }

    /**
     * Asynchronous counterpart of {@link #getConnection()}; applies the same error handling before completing the
     * callback.
     *
     * @param callback completed with the wrapped connection, or with the error raised by the pool
     */
    @Override
    public void getConnectionAsync(final SingleResultCallback<AsyncConnection> callback) {
        isTrue("open", !isClosed());
        connectionPool.getAsync(new SingleResultCallback<InternalConnection>() {
            @Override
            public void onResult(final InternalConnection result, final Throwable t) {
                if (t instanceof MongoSecurityException) {
                    connectionPool.invalidate();
                } else if (t instanceof MongoSocketException) {
                    invalidate(BEFORE_HANDSHAKE, t, connectionPool.getGeneration(), description.getMaxWireVersion());
                }
                if (t != null) {
                    callback.onResult(null, t);
                } else {
                    callback.onResult(connectionFactory.createAsync(result, new DefaultServerProtocolExecutor(), clusterConnectionMode),
                            null);
                }
            }
        });
    }

    /**
     * @return the current server description
     */
    @Override
    public ServerDescription getDescription() {
        isTrue("open", !isClosed());
        return description;
    }

    /**
     * Offers a bare {@code CONNECTING} description to the state listener. Neither the pool nor the monitor is touched.
     */
    @Override
    public void resetToConnecting() {
        offerDescription(connectingDescription());
    }

    /**
     * Resets the description to {@code CONNECTING}, asks the monitor to connect, and invalidates the connection pool
     * when the server's max wire version is below {@code FOUR_DOT_TWO_WIRE_VERSION}. Does nothing once closed.
     */
    @Override
    public synchronized void invalidate() {
        if (isClosed()) {
            return;
        }
        // Read the wire version BEFORE the reset: the replacement description is built without any wire version, so
        // testing it after the reset would evaluate the placeholder instead of what the server last reported.
        final int maxWireVersion = description.getMaxWireVersion();
        offerDescription(connectingDescription());
        connect();
        if (maxWireVersion < FOUR_DOT_TWO_WIRE_VERSION) {
            connectionPool.invalidate();
        }
    }

    /**
     * Reacts to an error observed on a connection of this server.
     *
     * <ul>
     *     <li>Errors from connections whose generation is older than the pool's current generation are ignored.</li>
     *     <li>A {@link MongoSocketException} resets the description, invalidates the pool and cancels the monitor's
     *     current check - except for a {@link MongoSocketReadTimeoutException} seen after the handshake, which is
     *     ignored.</li>
     *     <li>A {@link MongoNotPrimaryException} or {@link MongoNodeIsRecoveringException} that is not stale resets the
     *     description and asks the monitor to connect; the pool is invalidated only when {@code maxWireVersion} is
     *     below {@code FOUR_DOT_TWO_WIRE_VERSION} or the error code is one of {@code SHUTDOWN_CODES}.</li>
     *     <li>Any other throwable is ignored.</li>
     * </ul>
     *
     * @param connectionState      whether the error happened before or after the connection handshake
     * @param t                    the error
     * @param connectionGeneration the generation of the connection on which the error was observed
     * @param maxWireVersion       the max wire version of the connection on which the error was observed
     */
    @Override
    public synchronized void invalidate(final ConnectionState connectionState, final Throwable t, final int connectionGeneration,
                                        final int maxWireVersion) {
        if (isClosed()) {
            return;
        }
        if (connectionGeneration < connectionPool.getGeneration()) {
            return;
        }

        if (t instanceof MongoSocketException
                && (!(t instanceof MongoSocketReadTimeoutException) || connectionState == BEFORE_HANDSHAKE)) {
            offerDescription(connectingDescription(t));
            connectionPool.invalidate();
            serverMonitor.cancelCurrentCheck();
        } else if (t instanceof MongoNotPrimaryException || t instanceof MongoNodeIsRecoveringException) {
            final MongoCommandException commandException = (MongoCommandException) t;
            if (isStale(commandException)) {
                return;
            }
            offerDescription(connectingDescription(t));
            connect();
            if (maxWireVersion < FOUR_DOT_TWO_WIRE_VERSION || SHUTDOWN_CODES.contains(commandException.getErrorCode())) {
                connectionPool.invalidate();
            }
        }
    }

    /**
     * Builds a {@code CONNECTING} description for this server's address with no other attributes set.
     */
    private ServerDescription connectingDescription() {
        return ServerDescription.builder().state(CONNECTING).address(serverId.getAddress()).build();
    }

    /**
     * Builds a {@code CONNECTING} description for this server's address that records the error that caused the reset.
     */
    private ServerDescription connectingDescription(final Throwable cause) {
        return ServerDescription.builder().state(CONNECTING).address(serverId.getAddress()).exception(cause).build();
    }

    /**
     * Hands a candidate description to the state listener, which decides whether it replaces the current one.
     */
    private void offerDescription(final ServerDescription candidate) {
        serverStateListener.stateChanged(new ChangeEvent<>(description, candidate));
    }

    /**
     * Tells whether a command error carries a topology version that is not newer than the current description's.
     * An error response without a {@code topologyVersion} field is never stale.
     */
    private boolean isStale(final MongoCommandException t) {
        if (!t.getResponse().containsKey("topologyVersion")) {
            return false;
        }
        return isStale(description.getTopologyVersion(), new TopologyVersion(t.getResponse().getDocument("topologyVersion")));
    }

    /**
     * A candidate is stale only when both versions are present, come from the same process id, and the candidate's
     * counter is less than or equal to the current counter.
     */
    private boolean isStale(final TopologyVersion currentTopologyVersion, final TopologyVersion candidateTopologyVersion) {
        return candidateTopologyVersion != null
                && currentTopologyVersion != null
                && candidateTopologyVersion.getProcessId().equals(currentTopologyVersion.getProcessId())
                && candidateTopologyVersion.getCounter() <= currentTopologyVersion.getCounter();
    }

    /**
     * Closes the pool and the monitor, marks the server closed and publishes a {@link ServerClosedEvent}.
     * Calling it on an already closed server does nothing.
     */
    @Override
    public void close() {
        if (!isClosed()) {
            connectionPool.close();
            serverMonitor.close();
            isClosed = true;
            serverListener.serverClosed(new ServerClosedEvent(serverId));
        }
    }

    @Override
    public boolean isClosed() {
        return isClosed;
    }

    /**
     * Delegates to {@link ServerMonitor#connect()}.
     */
    @Override
    public void connect() {
        serverMonitor.connect();
    }

    ConnectionPool getConnectionPool() {
        return connectionPool;
    }

    /**
     * Executes protocols on behalf of connections created by this server and reports any failure back to the
     * enclosing server as an after-handshake error.
     */
    private class DefaultServerProtocolExecutor implements ProtocolExecutor {

        /**
         * Runs a legacy protocol; any {@link MongoException} is reported to the server and rethrown.
         */
        @Override
        public <T> T execute(final LegacyProtocol<T> protocol, final InternalConnection connection) {
            try {
                protocol.setCommandListener(commandListener);
                return protocol.execute(connection);
            } catch (MongoException e) {
                invalidate(AFTER_HANDSHAKE, e, connection.getGeneration(), connection.getDescription().getMaxWireVersion());
                throw e;
            }
        }

        /**
         * Asynchronous variant of {@link #execute(LegacyProtocol, InternalConnection)}; any error is reported to the
         * server before the caller's callback is completed.
         */
        @Override
        public <T> void executeAsync(final LegacyProtocol<T> protocol, final InternalConnection connection,
                                     final SingleResultCallback<T> callback) {
            protocol.setCommandListener(commandListener);
            protocol.executeAsync(connection, errorHandlingCallback(new SingleResultCallback<T>() {
                @Override
                public void onResult(final T result, final Throwable t) {
                    if (t != null) {
                        invalidate(AFTER_HANDSHAKE, t, connection.getGeneration(), connection.getDescription().getMaxWireVersion());
                    }
                    callback.onResult(result, t);
                }
            }, LOGGER));
        }

        /**
         * Runs a command protocol with a session context that also advances the cluster clock.
         *
         * <p>A {@link MongoWriteConcernWithResponseException} invalidates the server but is not propagated: its
         * response is returned as the command result. Any other {@link MongoException} is reported to the server and
         * rethrown; if it is a socket error and a session is present, the session is marked dirty first.</p>
         */
        @SuppressWarnings("unchecked")
        @Override
        public <T> T execute(final CommandProtocol<T> protocol, final InternalConnection connection,
                             final SessionContext sessionContext) {
            try {
                protocol.sessionContext(new ClusterClockAdvancingSessionContext(sessionContext, clusterClock));
                return protocol.execute(connection);
            } catch (MongoWriteConcernWithResponseException e) {
                invalidate();
                return (T) e.getResponse();
            } catch (MongoException e) {
                invalidate(AFTER_HANDSHAKE, e, connection.getGeneration(), connection.getDescription().getMaxWireVersion());
                if (e instanceof MongoSocketException && sessionContext.hasSession()) {
                    sessionContext.markSessionDirty();
                }
                throw e;
            }
        }

        /**
         * Asynchronous variant of {@link #execute(CommandProtocol, InternalConnection, SessionContext)} with the same
         * error handling, delivered through the callback instead of by return value or exception.
         */
        @SuppressWarnings("unchecked")
        @Override
        public <T> void executeAsync(final CommandProtocol<T> protocol, final InternalConnection connection,
                                     final SessionContext sessionContext, final SingleResultCallback<T> callback) {
            protocol.sessionContext(new ClusterClockAdvancingSessionContext(sessionContext, clusterClock));
            protocol.executeAsync(connection, errorHandlingCallback(new SingleResultCallback<T>() {
                @Override
                public void onResult(final T result, final Throwable t) {
                    if (t == null) {
                        callback.onResult(result, null);
                        return;
                    }
                    if (t instanceof MongoWriteConcernWithResponseException) {
                        invalidate();
                        callback.onResult((T) ((MongoWriteConcernWithResponseException) t).getResponse(), null);
                        return;
                    }
                    invalidate(AFTER_HANDSHAKE, t, connection.getGeneration(), connection.getDescription().getMaxWireVersion());
                    if (t instanceof MongoSocketException && sessionContext.hasSession()) {
                        sessionContext.markSessionDirty();
                    }
                    callback.onResult(null, t);
                }
            }, LOGGER));
        }
    }

    /**
     * Receives candidate descriptions and installs them as the server's current description unless they carry an
     * older topology version from the same server process.
     */
    private final class DefaultServerStateListener implements ChangeListener<ServerDescription> {

        /**
         * Replaces the current description with the event's new value when allowed and, if replaced, publishes a
         * {@link ServerDescriptionChangedEvent}. The event's previous value is not consulted; the comparison is made
         * against the description held by the server at the time of the call.
         */
        @Override
        public void stateChanged(final ChangeEvent<ServerDescription> event) {
            ServerDescription oldDescription = description;
            if (shouldReplace(oldDescription, event.getNewValue())) {
                description = event.getNewValue();
                serverListener.serverDescriptionChanged(new ServerDescriptionChangedEvent(serverId, description, oldDescription));
            }
        }

        /**
         * A new description is rejected only when both topology versions are present, share a process id, and the new
         * counter is strictly lower than the old one.
         */
        private boolean shouldReplace(final ServerDescription oldDescription, final ServerDescription newDescription) {
            TopologyVersion oldTopologyVersion = oldDescription.getTopologyVersion();
            TopologyVersion newTopologyVersion = newDescription.getTopologyVersion();
            return newTopologyVersion == null
                    || oldTopologyVersion == null
                    || !newTopologyVersion.getProcessId().equals(oldTopologyVersion.getProcessId())
                    || newTopologyVersion.getCounter() >= oldTopologyVersion.getCounter();
        }
    }
}