/*
 * Copyright 2008-present MongoDB, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.mongodb.internal.connection;

import com.mongodb.AuthenticationMechanism;
import com.mongodb.MongoCredential;
import com.mongodb.ServerAddress;
import com.mongodb.internal.authentication.SaslPrep;
import org.bson.internal.Base64;

import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslException;
import java.io.UnsupportedEncodingException;
import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.util.HashMap;
import java.util.Random;

import static com.mongodb.AuthenticationMechanism.SCRAM_SHA_1;
import static com.mongodb.AuthenticationMechanism.SCRAM_SHA_256;
import static com.mongodb.internal.authentication.NativeAuthenticationHelper.createAuthenticationHash;
import static java.lang.String.format;

class ScramShaAuthenticator extends SaslAuthenticator {
    private final RandomStringGenerator randomStringGenerator;
    private final AuthenticationHashGenerator authenticationHashGenerator;

    private static final int MINIMUM_ITERATION_COUNT = 4096;
    private static final String GS2_HEADER = "n,,";
    private static final int RANDOM_LENGTH = 24;
    private static final byte[] INT_1 = new byte[]{0, 0, 0, 1};

    ScramShaAuthenticator(final MongoCredentialWithCache credential) {
        this(credential, new DefaultRandomStringGenerator(), getAuthenicationHashGenerator(credential.getAuthenticationMechanism()));
    }

    ScramShaAuthenticator(final MongoCredentialWithCache credential, final RandomStringGenerator randomStringGenerator) {
        this(credential, randomStringGenerator, getAuthenicationHashGenerator(credential.getAuthenticationMechanism()));
    }

    ScramShaAuthenticator(final MongoCredentialWithCache credential, final RandomStringGenerator randomStringGenerator,
                          final AuthenticationHashGenerator authenticationHashGenerator) {
        super(credential);
        this.randomStringGenerator = randomStringGenerator;
        this.authenticationHashGenerator = authenticationHashGenerator;
    }

    @Override
    public String getMechanismName() {
        AuthenticationMechanism authMechanism = getMongoCredential().getAuthenticationMechanism();
        if (authMechanism == null) {
            throw new IllegalArgumentException("Authentication mechanism cannot be null");
        }
        return authMechanism.getMechanismName();
    }

    @Override
    protected SaslClient createSaslClient(final ServerAddress serverAddress) {
        return new ScramShaSaslClient(getMongoCredentialWithCache(), randomStringGenerator, authenticationHashGenerator);
    }

    class ScramShaSaslClient implements SaslClient {

        private final MongoCredentialWithCache credential;
        private final RandomStringGenerator randomStringGenerator;
        private final AuthenticationHashGenerator authenticationHashGenerator;
        private final String hAlgorithm;
        private final String hmacAlgorithm;

        private String clientFirstMessageBare;
        private String clientNonce;

        private byte[] serverSignature;
        private int step = -1;

        ScramShaSaslClient(final MongoCredentialWithCache credential, final RandomStringGenerator randomStringGenerator,
                           final AuthenticationHashGenerator authenticationHashGenerator) {
            this.credential = credential;
            this.randomStringGenerator = randomStringGenerator;
            this.authenticationHashGenerator = authenticationHashGenerator;
            if (credential.getAuthenticationMechanism().equals(SCRAM_SHA_1)) {
                hAlgorithm = "SHA-1";
                hmacAlgorithm = "HmacSHA1";
            } else {
                hAlgorithm = "SHA-256";
                hmacAlgorithm = "HmacSHA256";
            }
        }

        public String getMechanismName() {
            return credential.getAuthenticationMechanism().getMechanismName();
        }

        public boolean hasInitialResponse() {
            return true;
        }

        public byte[] evaluateChallenge(final byte[] challenge) throws SaslException {
            step++;
            if (step == 0) {
                return computeClientFirstMessage();
            } else if (step == 1) {
                return computeClientFinalMessage(challenge);
            } else if (step == 2) {
                return validateServerSignature(challenge);
            } else {
                throw new SaslException(format("Too many steps involved in the %s negotiation.", getMechanismName()));
            }
        }

        private byte[] validateServerSignature(final byte[] challenge) throws SaslException {
            String serverResponse = encodeUTF8(challenge);
            HashMap<String, String> map = parseServerResponse(serverResponse);
            if (!MessageDigest.isEqual(decodeBase64(map.get("v")), serverSignature)) {
                throw new SaslException("Server signature was invalid.");
            }
            return challenge;
        }

        public boolean isComplete() {
            return step == 2;
        }

        public byte[] unwrap(final byte[] incoming, final int offset, final int len) {
            throw new UnsupportedOperationException("Not implemented yet!");
        }

        public byte[] wrap(final byte[] outgoing, final int offset, final int len) {
            throw new UnsupportedOperationException("Not implemented yet!");
        }

        public Object getNegotiatedProperty(final String propName) {
            throw new UnsupportedOperationException("Not implemented yet!");
        }

        public void dispose() {
            // nothing to do
        }

        private byte[] computeClientFirstMessage() throws SaslException {
            clientNonce = randomStringGenerator.generate(RANDOM_LENGTH);
            String clientFirstMessage = "n=" + getUserName() + ",r=" + clientNonce;
            clientFirstMessageBare = clientFirstMessage;
            return decodeUTF8(GS2_HEADER + clientFirstMessage);
        }

        private byte[] computeClientFinalMessage(final byte[] challenge) throws SaslException {
            String serverFirstMessage = encodeUTF8(challenge);
            HashMap<String, String> map = parseServerResponse(serverFirstMessage);
            String serverNonce = map.get("r");
            if (!serverNonce.startsWith(clientNonce)) {
                throw new SaslException("Server sent an invalid nonce.");
            }

            String salt = map.get("s");
            int iterationCount = Integer.parseInt(map.get("i"));
            if (iterationCount < MINIMUM_ITERATION_COUNT) {
                throw new SaslException("Invalid iteration count.");
            }

            String clientFinalMessageWithoutProof = "c=" + encodeBase64(GS2_HEADER) + ",r=" + serverNonce;
            String authMessage = clientFirstMessageBare + "," + serverFirstMessage + "," + clientFinalMessageWithoutProof;
            String clientFinalMessage = clientFinalMessageWithoutProof + ",p="
                    + getClientProof(getAuthenicationHash(), salt, iterationCount, authMessage);
            return decodeUTF8(clientFinalMessage);
        }

        
The client Proof:

AuthMessage := client-first-message-bare + "," + server-first-message + "," + client-final-message-without-proof SaltedPassword := Hi(Normalize(password), salt, i) ClientKey := HMAC(SaltedPassword, "Client Key") ServerKey := HMAC(SaltedPassword, "Server Key") StoredKey := H(ClientKey) ClientSignature := HMAC(StoredKey, AuthMessage) ClientProof := ClientKey XOR ClientSignature ServerSignature := HMAC(ServerKey, AuthMessage)

/** * The client Proof: * <p> * AuthMessage := client-first-message-bare + "," + server-first-message + "," + client-final-message-without-proof * SaltedPassword := Hi(Normalize(password), salt, i) * ClientKey := HMAC(SaltedPassword, "Client Key") * ServerKey := HMAC(SaltedPassword, "Server Key") * StoredKey := H(ClientKey) * ClientSignature := HMAC(StoredKey, AuthMessage) * ClientProof := ClientKey XOR ClientSignature * ServerSignature := HMAC(ServerKey, AuthMessage) */
String getClientProof(final String password, final String salt, final int iterationCount, final String authMessage) throws SaslException { String hashedPasswordAndSalt = encodeUTF8(h(decodeUTF8(password + salt))); CacheKey cacheKey = new CacheKey(hashedPasswordAndSalt, salt, iterationCount); CacheValue cachedKeys = getMongoCredentialWithCache().getFromCache(cacheKey, CacheValue.class); if (cachedKeys == null) { byte[] saltedPassword = hi(decodeUTF8(password), decodeBase64(salt), iterationCount); byte[] clientKey = hmac(saltedPassword, "Client Key"); byte[] serverKey = hmac(saltedPassword, "Server Key"); cachedKeys = new CacheValue(clientKey, serverKey); getMongoCredentialWithCache().putInCache(cacheKey, new CacheValue(clientKey, serverKey)); } serverSignature = hmac(cachedKeys.serverKey, authMessage); byte[] storedKey = h(cachedKeys.clientKey); byte[] clientSignature = hmac(storedKey, authMessage); byte[] clientProof = xor(cachedKeys.clientKey, clientSignature); return encodeBase64(clientProof); } private byte[] decodeBase64(final String str) { return Base64.decode(str); } private byte[] decodeUTF8(final String str) throws SaslException { try { return str.getBytes("UTF-8"); } catch (UnsupportedEncodingException e) { throw new SaslException("UTF-8 is not a supported encoding.", e); } } private String encodeBase64(final String str) throws SaslException { return Base64.encode(decodeUTF8(str)); } private String encodeBase64(final byte[] bytes) { return Base64.encode(bytes); } private String encodeUTF8(final byte[] bytes) throws SaslException { try { return new String(bytes, "UTF-8"); } catch (UnsupportedEncodingException e) { throw new SaslException("UTF-8 is not a supported encoding.", e); } } private byte[] h(final byte[] data) throws SaslException { try { return MessageDigest.getInstance(hAlgorithm).digest(data); } catch (NoSuchAlgorithmException e) { throw new SaslException(format("Algorithm for '%s' could not be found.", hAlgorithm), e); } } private byte[] hi(final byte[] password, final byte[] salt, final int iterations) throws SaslException { try { SecretKeySpec key = new SecretKeySpec(password, hmacAlgorithm); Mac mac = Mac.getInstance(hmacAlgorithm); mac.init(key); mac.update(salt); mac.update(INT_1); byte[] result = mac.doFinal(); byte[] previous = null; for (int i = 1; i < iterations; i++) { mac.update(previous != null ? previous : result); previous = mac.doFinal(); xorInPlace(result, previous); } return result; } catch (NoSuchAlgorithmException e) { throw new SaslException(format("Algorithm for '%s' could not be found.", hmacAlgorithm), e); } catch (InvalidKeyException e) { throw new SaslException(format("Invalid key for %s", hmacAlgorithm), e); } } private byte[] hmac(final byte[] bytes, final String key) throws SaslException { try { Mac mac = Mac.getInstance(hmacAlgorithm); mac.init(new SecretKeySpec(bytes, hmacAlgorithm)); return mac.doFinal(decodeUTF8(key)); } catch (NoSuchAlgorithmException e) { throw new SaslException(format("Algorithm for '%s' could not be found.", hmacAlgorithm), e); } catch (InvalidKeyException e) { throw new SaslException("Could not initialize mac.", e); } }
The server provides back key value pairs using an = sign and delimited by a command. All keys are also a single character. For example: a=kg4io3,b=skljsfoiew,c=1203
/** * The server provides back key value pairs using an = sign and delimited * by a command. All keys are also a single character. * For example: a=kg4io3,b=skljsfoiew,c=1203 */
private HashMap<String, String> parseServerResponse(final String response) { HashMap<String, String> map = new HashMap<String, String>(); String[] pairs = response.split(","); for (String pair : pairs) { String[] parts = pair.split("=", 2); map.put(parts[0], parts[1]); } return map; } private String getUserName() { String userName = credential.getCredential().getUserName(); if (userName == null) { throw new IllegalArgumentException("Username can not be null"); } return userName.replace("=", "=3D").replace(",", "=2C"); } private String getAuthenicationHash() { String password = authenticationHashGenerator.generate(credential.getCredential()); if (credential.getAuthenticationMechanism() == SCRAM_SHA_256) { password = SaslPrep.saslPrepStored(password); } return password; } private byte[] xorInPlace(final byte[] a, final byte[] b) { for (int i = 0; i < a.length; i++) { a[i] ^= b[i]; } return a; } private byte[] xor(final byte[] a, final byte[] b) { byte[] result = new byte[a.length]; System.arraycopy(a, 0, result, 0, a.length); return xorInPlace(result, b); } } public interface RandomStringGenerator { String generate(int length); } public interface AuthenticationHashGenerator { String generate(MongoCredential credential); } private static class DefaultRandomStringGenerator implements RandomStringGenerator { public String generate(final int length) { Random random = new SecureRandom(); int comma = 44; int low = 33; int high = 126; int range = high - low; char[] text = new char[length]; for (int i = 0; i < length; i++) { int next = random.nextInt(range) + low; while (next == comma) { next = random.nextInt(range) + low; } text[i] = (char) next; } return new String(text); } } private static final AuthenticationHashGenerator DEFAULT_AUTHENTICATION_HASH_GENERATOR = new AuthenticationHashGenerator() { @Override public String generate(final MongoCredential credential) { char[] password = credential.getPassword(); if (password == null) { throw new IllegalArgumentException("Password must not be null"); } return new String(password); } }; private static final AuthenticationHashGenerator LEGACY_AUTHENTICATION_HASH_GENERATOR = new AuthenticationHashGenerator() { @Override public String generate(final MongoCredential credential) { // Username and password must not be modified going into the hash. String username = credential.getUserName(); char[] password = credential.getPassword(); if (username == null || password == null) { throw new IllegalArgumentException("Username and password must not be null"); } return createAuthenticationHash(username, password); } }; private static AuthenticationHashGenerator getAuthenicationHashGenerator(final AuthenticationMechanism authenticationMechanism) { return authenticationMechanism == SCRAM_SHA_1 ? LEGACY_AUTHENTICATION_HASH_GENERATOR : DEFAULT_AUTHENTICATION_HASH_GENERATOR; } private static class CacheKey { private final String hashedPasswordAndSalt; private final String salt; private final int iterationCount; CacheKey(final String hashedPasswordAndSalt, final String salt, final int iterationCount) { this.hashedPasswordAndSalt = hashedPasswordAndSalt; this.salt = salt; this.iterationCount = iterationCount; } @Override public boolean equals(final Object o) { if (this == o) { return true; } if (o == null || getClass() != o.getClass()) { return false; } CacheKey that = (CacheKey) o; if (iterationCount != that.iterationCount) { return false; } if (!hashedPasswordAndSalt.equals(that.hashedPasswordAndSalt)) { return false; } return salt.equals(that.salt); } @Override public int hashCode() { int result = hashedPasswordAndSalt.hashCode(); result = 31 * result + salt.hashCode(); result = 31 * result + iterationCount; return result; } } private static class CacheValue { private byte[] clientKey; private byte[] serverKey; CacheValue(final byte[] clientKey, final byte[] serverKey) { this.clientKey = clientKey; this.serverKey = serverKey; } } }