/*
 * Microsoft JDBC Driver for SQL Server Copyright(c) Microsoft Corporation All rights reserved. This program is made
 * available under the terms of the MIT License. See the LICENSE file in the project root for more information.
 */

package com.microsoft.sqlserver.jdbc;

import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URI;
import java.net.URISyntaxException;
import java.text.MessageFormat;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.logging.Level;
import javax.security.auth.kerberos.KerberosPrincipal;
import com.microsoft.aad.msal4j.IAccount;
import com.microsoft.aad.msal4j.ClientCredentialFactory;
import com.microsoft.aad.msal4j.ClientCredentialParameters;
import com.microsoft.aad.msal4j.ConfidentialClientApplication;
import com.microsoft.aad.msal4j.IAuthenticationResult;
import com.microsoft.aad.msal4j.IClientCredential;
import com.microsoft.aad.msal4j.IntegratedWindowsAuthenticationParameters;
import com.microsoft.aad.msal4j.InteractiveRequestParameters;
import com.microsoft.aad.msal4j.MsalInteractionRequiredException;
import com.microsoft.aad.msal4j.PublicClientApplication;
import com.microsoft.aad.msal4j.SilentParameters;
import com.microsoft.aad.msal4j.SystemBrowserOptions;
import com.microsoft.aad.msal4j.UserNamePasswordParameters;
import com.microsoft.sqlserver.jdbc.SQLServerConnection.ActiveDirectoryAuthentication;

import com.microsoft.sqlserver.jdbc.SQLServerConnection.SqlFedAuthInfo;


/**
 * Helper methods that acquire Azure Active Directory access tokens through MSAL4J for the different
 * ActiveDirectory authentication modes (password, service principal, integrated and interactive).
 *
 * Every token method creates its own single-thread executor for the MSAL client application and shuts it down
 * before returning.
 */
class SQLServerMSAL4JUtils {

    static final String REDIRECTURI = "http://localhost";
    private static final String SLASH_DEFAULT = "/.default";

    private static final java.util.logging.Logger logger = java.util.logging.Logger
            .getLogger("com.microsoft.sqlserver.jdbc.SQLServerMSAL4JUtils");

    /**
     * Acquires a token with the user name / password flow of a public client application.
     *
     * @param fedAuthInfo
     *        federated authentication info supplying the authority URL (stsurl) and the resource (spn)
     * @param user
     *        user name
     * @param password
     *        password of the user
     * @param authenticationString
     *        authentication mode, only used to build the error message
     * @return the access token together with its expiry date
     * @throws SQLServerException
     *         if the authority URL is malformed, the calling thread is interrupted or token acquisition fails
     */
    static SqlFedAuthToken getSqlFedAuthToken(SqlFedAuthInfo fedAuthInfo, String user, String password,
            String authenticationString) throws SQLServerException {
        ExecutorService executorService = Executors.newSingleThreadExecutor();

        try {
            final PublicClientApplication pca = PublicClientApplication
                    .builder(ActiveDirectoryAuthentication.JDBC_FEDAUTH_CLIENT_ID).executorService(executorService)
                    .authority(fedAuthInfo.stsurl).build();
            final CompletableFuture<IAuthenticationResult> future = pca.acquireToken(UserNamePasswordParameters
                    .builder(Collections.singleton(getDefaultScope(fedAuthInfo.spn)), user, password.toCharArray())
                    .build());

            final IAuthenticationResult authenticationResult = future.get();
            return new SqlFedAuthToken(authenticationResult.accessToken(), authenticationResult.expiresOnDate());
        } catch (InterruptedException e) {
            throw getInterruptedException(e);
        } catch (MalformedURLException e) {
            throw new SQLServerException(e.getMessage(), e);
        } catch (ExecutionException e) {
            throw getCorrectedException(e, user, authenticationString);
        } finally {
            executorService.shutdown();
        }
    }

    /**
     * Acquires a token with the client credential flow of a confidential client application.
     *
     * @param fedAuthInfo
     *        federated authentication info supplying the authority URL (stsurl) and the resource (spn)
     * @param aadPrincipalID
     *        client id of the principal
     * @param aadPrincipalSecret
     *        client secret of the principal
     * @param authenticationString
     *        authentication mode, only used to build the error message
     * @return the access token together with its expiry date
     * @throws SQLServerException
     *         if the authority URL is malformed, the calling thread is interrupted or token acquisition fails
     */
    static SqlFedAuthToken getSqlFedAuthTokenPrincipal(SqlFedAuthInfo fedAuthInfo, String aadPrincipalID,
            String aadPrincipalSecret, String authenticationString) throws SQLServerException {
        ExecutorService executorService = Executors.newSingleThreadExecutor();
        try {
            Set<String> scopes = new HashSet<>();
            scopes.add(getDefaultScope(fedAuthInfo.spn));
            IClientCredential credential = ClientCredentialFactory.createFromSecret(aadPrincipalSecret);
            ConfidentialClientApplication clientApplication = ConfidentialClientApplication
                    .builder(aadPrincipalID, credential).executorService(executorService).authority(fedAuthInfo.stsurl)
                    .build();
            final CompletableFuture<IAuthenticationResult> future = clientApplication
                    .acquireToken(ClientCredentialParameters.builder(scopes).build());
            final IAuthenticationResult authenticationResult = future.get();
            return new SqlFedAuthToken(authenticationResult.accessToken(), authenticationResult.expiresOnDate());
        } catch (InterruptedException e) {
            throw getInterruptedException(e);
        } catch (MalformedURLException e) {
            throw new SQLServerException(e.getMessage(), e);
        } catch (ExecutionException e) {
            throw getCorrectedException(e, aadPrincipalID, authenticationString);
        } finally {
            executorService.shutdown();
        }
    }

    /**
     * Acquires a token with the integrated Windows authentication flow of a public client application.
     *
     * @param fedAuthInfo
     *        federated authentication info supplying the authority URL (stsurl) and the resource (spn)
     * @param authenticationString
     *        authentication mode, only used to build the error message
     * @return the access token together with its expiry date
     * @throws SQLServerException
     *         if an I/O error occurs, the calling thread is interrupted or token acquisition fails
     */
    static SqlFedAuthToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAuthInfo,
            String authenticationString) throws SQLServerException {
        ExecutorService executorService = Executors.newSingleThreadExecutor();

        try {
            /*
             * The principal name does not matter; what matters is the realm name. getName() returns the user name
             * in principal_name@realm_name format.
             */
            KerberosPrincipal kerberosPrincipal = new KerberosPrincipal("username");
            String user = kerberosPrincipal.getName();

            if (logger.isLoggable(Level.FINE)) {
                logger.fine(logger.toString() + " realm name is:" + kerberosPrincipal.getRealm());
            }

            final PublicClientApplication pca = PublicClientApplication
                    .builder(ActiveDirectoryAuthentication.JDBC_FEDAUTH_CLIENT_ID).executorService(executorService)
                    .authority(fedAuthInfo.stsurl).build();
            final CompletableFuture<IAuthenticationResult> future = pca
                    .acquireToken(IntegratedWindowsAuthenticationParameters
                            .builder(Collections.singleton(getDefaultScope(fedAuthInfo.spn)), user).build());

            final IAuthenticationResult authenticationResult = future.get();
            return new SqlFedAuthToken(authenticationResult.accessToken(), authenticationResult.expiresOnDate());
        } catch (InterruptedException e) {
            throw getInterruptedException(e);
        } catch (IOException e) {
            throw new SQLServerException(e.getMessage(), e);
        } catch (ExecutionException e) {
            throw getCorrectedException(e, "", authenticationString);
        } finally {
            executorService.shutdown();
        }
    }

    /**
     * Acquires a token for the interactive authentication mode. If an account matching {@code user} is found in the
     * token cache a silent acquisition is attempted first; otherwise, or when the silent attempt reports that user
     * interaction is required, the token is acquired interactively through the system browser.
     *
     * @param fedAuthInfo
     *        federated authentication info supplying the authority URL (stsurl) and the resource (spn)
     * @param user
     *        user name, used to look up a cached account and as login hint; may be null or empty
     * @param authenticationString
     *        authentication mode, only used to build the error message
     * @return the access token together with its expiry date
     * @throws SQLServerException
     *         if the authority URL or redirect URI is invalid, the calling thread is interrupted or token
     *         acquisition fails
     */
    static SqlFedAuthToken getSqlFedAuthTokenInteractive(SqlFedAuthInfo fedAuthInfo, String user,
            String authenticationString) throws SQLServerException {
        ExecutorService executorService = Executors.newSingleThreadExecutor();

        try {
            PublicClientApplication pca = PublicClientApplication
                    .builder(ActiveDirectoryAuthentication.JDBC_FEDAUTH_CLIENT_ID).executorService(executorService)
                    .setTokenCacheAccessAspect(PersistentTokenCacheAccessAspect.getInstance())
                    .authority(fedAuthInfo.stsurl).logPii((logger.isLoggable(Level.FINE))).build();

            final Set<String> scopes = Collections.singleton(getDefaultScope(fedAuthInfo.spn));
            CompletableFuture<IAuthenticationResult> future = null;
            IAuthenticationResult authenticationResult = null;

            // try to acquire token silently if user account found in cache
            try {
                Set<IAccount> accountsInCache = pca.getAccounts().join();
                if (null != accountsInCache && !accountsInCache.isEmpty() && null != user && !user.isEmpty()) {
                    IAccount account = getAccountByUsername(accountsInCache, user);
                    if (null != account) {
                        if (logger.isLoggable(Level.FINE)) {
                            logger.fine(logger.toString() + "Silent authentication for user:" + user);
                        }
                        SilentParameters silentParameters = SilentParameters.builder(scopes, account).build();

                        future = pca.acquireTokenSilently(silentParameters);
                    }
                }
            } catch (MsalInteractionRequiredException e) {
                // not an error, need to get token interactively
            }

            if (null != future) {
                try {
                    authenticationResult = future.get();
                } catch (ExecutionException e) {
                    /*
                     * A failure of the asynchronous silent flow is delivered through the future, wrapped in an
                     * ExecutionException. If interaction is required this is not an error: fall back to the
                     * interactive flow below. Anything else is handled by the outer catch.
                     */
                    if (!(e.getCause() instanceof MsalInteractionRequiredException)) {
                        throw e;
                    }
                }
            }

            if (null == authenticationResult) {
                // acquire token interactively with system browser
                if (logger.isLoggable(Level.FINE)) {
                    logger.fine(logger.toString() + "Interactive authentication");
                }
                InteractiveRequestParameters parameters = InteractiveRequestParameters.builder(new URI(REDIRECTURI))
                        .systemBrowserOptions(SystemBrowserOptions.builder()
                                .htmlMessageSuccess(SQLServerResource.getResource("R_MSALAuthComplete")).build())
                        .loginHint(user).scopes(scopes).build();

                future = pca.acquireToken(parameters);
                authenticationResult = future.get();
            }

            return new SqlFedAuthToken(authenticationResult.accessToken(), authenticationResult.expiresOnDate());
        } catch (InterruptedException e) {
            throw getInterruptedException(e);
        } catch (MalformedURLException | URISyntaxException e) {
            throw new SQLServerException(e.getMessage(), e);
        } catch (ExecutionException e) {
            throw getCorrectedException(e, user, authenticationString);
        } finally {
            executorService.shutdown();
        }
    }

    // Helper function to build the default scope for a resource without appending the suffix twice
    private static String getDefaultScope(String spn) {
        return spn.endsWith(SLASH_DEFAULT) ? spn : spn + SLASH_DEFAULT;
    }

    // Helper function to restore the interrupt status of the current thread and wrap the interruption
    private static SQLServerException getInterruptedException(InterruptedException e) {
        Thread.currentThread().interrupt();
        return new SQLServerException(e.getMessage(), e);
    }

    // Helper function to return account containing user name from set of accounts, or null if no match
    private static IAccount getAccountByUsername(Set<IAccount> accounts, String username) {
        for (IAccount account : accounts) {
            // called on username (checked non-null by the caller) so an account without user name cannot cause an NPE
            if (username.equalsIgnoreCase(account.username())) {
                return account;
            }
        }
        return null;
    }

    /**
     * Converts an ExecutionException raised by an MSAL future into a SQLServerException carrying the R_MSALExecution
     * message formatted with the user and the authentication mode.
     *
     * @param e
     *        exception thrown by the future
     * @param user
     *        user name (or principal id) placed into the error message
     * @param authenticationString
     *        authentication mode placed into the error message
     * @return the exception to be thrown to the caller
     */
    private static SQLServerException getCorrectedException(ExecutionException e, String user,
            String authenticationString) {
        if (logger.isLoggable(Level.FINE)) {
            logger.fine(logger.toString() + " MSAL exception:" + e.getMessage());
        }

        MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_MSALExecution"));
        Object[] msgArgs = {user, authenticationString};

        if (null == e.getCause() || null == e.getCause().getMessage()) {
            // The case when Future's outcome has no AuthenticationResult but Exception.
            return new SQLServerException(form.format(msgArgs), null);
        } else {
            /*
             * The cause's error message contains the literal characters \r\n (backslash-r backslash-n), which does
             * not give the correct format; replace them with a real CR LF.
             */
            String correctedErrorMessage = e.getCause().getMessage().replace("\\r\\n", "\r\n");
            RuntimeException correctedAuthenticationException = new RuntimeException(correctedErrorMessage);

            /*
             * SQLServerException is caused by ExecutionException, which is caused by AuthenticationException to match
             * the exception tree before error message correction
             */
            ExecutionException correctedExecutionException = new ExecutionException(correctedAuthenticationException);

            return new SQLServerException(form.format(msgArgs), null, 0, correctedExecutionException);
        }
    }
}