package com.microsoft.azure.credentials;
import com.microsoft.azure.AzureEnvironment;
import com.microsoft.azure.management.apigeneration.Beta;
import com.microsoft.azure.serializer.AzureJacksonAdapter;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.net.HttpURLConnection;
import java.net.URL;
import java.net.URLEncoder;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Calendar;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
/**
 * Token credentials backed by an Azure Managed Service Identity (MSI).
 *
 * <p>Tokens are obtained either from the instance metadata service (IMDS) endpoint at
 * {@code 169.254.169.254} or, when the deprecated port-based constructor is used, from the
 * MSI VM extension listening on {@code localhost}. Tokens retrieved from IMDS are cached per
 * token audience until {@code MSIToken.isExpired()} reports them expired; tokens retrieved
 * from the MSI extension are not cached.
 */
@Beta
public class MSICredentials extends AzureTokenCredentials {
    /** Charset name used for URL-encoding and for all request/response bodies. */
    private static final String UTF_8 = "UTF-8";
    /** Value of the {@code api-version} query parameter sent to IMDS. */
    private static final String IMDS_API_VERSION = "2018-02-01";
    /** Minimum back-off applied after an HTTP 410 response (presumably the IMDS upgrade window - confirm). */
    private static final int IMDS_UPGRADE_TIME_IN_MS = 70 * 1000;
    /** Shared generator used to pick a back-off slot; {@link Random} is safe for concurrent use. */
    private static final Random RANDOM = new Random();

    /** Back-off table in seconds; attempt {@code n} picks randomly among the first {@code n} entries. */
    private final List<Integer> retrySlots = new ArrayList<>(Arrays.asList(
            1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610, 987, 1597, 2584, 4181, 6765));
    private int maxRetry = retrySlots.size();
    /** Timeout in milliseconds applied to the back-off wait; {@code -1} means no timeout. */
    private int customTimeout = -1;
    /** Serialises IMDS token refreshes so a single thread refreshes an expired entry. */
    private final Lock lock = new ReentrantLock();
    /** IMDS tokens keyed by token audience. */
    private final ConcurrentHashMap<String, MSIToken> cache = new ConcurrentHashMap<>();
    /** Audience used when {@link #getToken(String)} is called with {@code null}. */
    private final String resource;
    private int msiPort = 50342;
    private final MSITokenSource tokenSource;
    private final AzureJacksonAdapter adapter = new AzureJacksonAdapter();
    private String objectId;
    private String clientId;
    private String identityId;

    /**
     * Creates credentials for {@link AzureEnvironment#AZURE} that retrieve tokens from the IMDS endpoint.
     */
    public MSICredentials() {
        this(AzureEnvironment.AZURE);
    }

    /**
     * Creates credentials that retrieve tokens from the IMDS endpoint.
     *
     * @param environment the Azure environment; its management endpoint is the default token audience
     */
    public MSICredentials(AzureEnvironment environment) {
        super(environment, null /** retrieving MSI token does not require tenant **/);
        this.resource = environment.managementEndpoint();
        this.tokenSource = MSITokenSource.IMDS_ENDPOINT;
    }

    /**
     * Creates credentials that retrieve tokens from the MSI VM extension on {@code localhost}.
     *
     * @param environment the Azure environment; its management endpoint is the default token audience
     * @param msiPort the local port the MSI extension listens on
     * @deprecated this constructor selects the MSI extension as token source; the other constructors
     *     use the IMDS endpoint instead.
     */
    @Deprecated()
    public MSICredentials(AzureEnvironment environment, int msiPort) {
        super(environment, null /** retrieving MSI token does not require tenant **/);
        this.resource = environment.managementEndpoint();
        this.msiPort = msiPort;
        this.tokenSource = MSITokenSource.MSI_EXTENSION;
    }

    /**
     * Selects the identity to request tokens for by its object id. Clears any previously
     * configured client id or identity id, and drops cached tokens because they were issued
     * for the previously selected identity.
     *
     * @param objectId the object id of the identity
     * @return this instance
     */
    @Beta
    public MSICredentials withObjectId(String objectId) {
        this.objectId = objectId;
        this.clientId = null;
        this.identityId = null;
        this.cache.clear();
        return this;
    }

    /**
     * Selects the identity to request tokens for by its client id. Clears any previously
     * configured object id or identity id, and drops cached tokens because they were issued
     * for the previously selected identity.
     *
     * @param clientId the client id of the identity
     * @return this instance
     */
    @Beta
    public MSICredentials withClientId(String clientId) {
        this.clientId = clientId;
        this.objectId = null;
        this.identityId = null;
        this.cache.clear();
        return this;
    }

    /**
     * Selects the identity to request tokens for by its identity id (sent as {@code msi_res_id}).
     * Clears any previously configured object id or client id, and drops cached tokens because
     * they were issued for the previously selected identity.
     *
     * @param identityId the id of the identity
     * @return this instance
     */
    @Beta
    public MSICredentials withIdentityId(String identityId) {
        this.identityId = identityId;
        this.clientId = null;
        this.objectId = null;
        this.cache.clear();
        return this;
    }

    /**
     * Sets the timeout that limits how long the IMDS retry loop waits between attempts.
     *
     * @param timeoutInMs the timeout in milliseconds; values below zero disable the timeout
     * @return this instance
     */
    public MSICredentials withCustomTimeout(int timeoutInMs) {
        this.customTimeout = timeoutInMs;
        return this;
    }

    /**
     * Retrieves an access token for the given audience from the configured token source.
     *
     * @param tokenAudience the resource to request a token for; {@code null} selects the
     *     management endpoint of the environment passed to the constructor
     * @return the access token
     * @throws IOException if the request to the MSI extension fails
     * @throws RuntimeException if the token cannot be acquired from IMDS
     */
    @Override
    public String getToken(String tokenAudience) throws IOException {
        String audience = tokenAudience == null ? this.resource : tokenAudience;
        if (this.tokenSource == MSITokenSource.MSI_EXTENSION) {
            return getTokenFromMSIExtension(audience);
        }
        return getTokenFromIMDSEndpoint(audience);
    }

    /**
     * Requests a token from the MSI VM extension with a form-encoded POST.
     *
     * @param tokenAudience the resource to request a token for
     * @return the access token
     * @throws IOException if the request fails or the response carries no token
     */
    private String getTokenFromMSIExtension(String tokenAudience) throws IOException {
        URL url = new URL(String.format("http://localhost:%d/oauth2/token", this.msiPort));
        // Values are URL-encoded so characters such as '&', '=' or '+' cannot corrupt the form body.
        String postData = buildRequestParameters(tokenAudience, false);
        HttpURLConnection connection = null;
        try {
            connection = (HttpURLConnection) url.openConnection();
            connection.setRequestMethod("POST");
            connection.setRequestProperty("Content-Type", "application/x-www-form-urlencoded; charset=utf-8");
            connection.setRequestProperty("Metadata", "true");
            // Content-Length counts bytes, not chars.
            connection.setRequestProperty("Content-Length", Integer.toString(postData.getBytes(UTF_8).length));
            connection.setDoOutput(true);
            connection.connect();
            try (OutputStreamWriter writer = new OutputStreamWriter(connection.getOutputStream(), UTF_8)) {
                writer.write(postData);
            }
            MSIToken msiToken = readToken(connection);
            if (msiToken == null) {
                throw new IOException("MSI: the MSI extension returned an empty token response");
            }
            return msiToken.accessToken();
        } finally {
            if (connection != null) {
                connection.disconnect();
            }
        }
    }

    /**
     * Returns a cached, unexpired IMDS token for the audience or retrieves and caches a new one.
     * The cache is checked once without the lock and again after acquiring it, so concurrent
     * callers do not each refresh the same entry.
     *
     * @param tokenAudience the resource to request a token for
     * @return the access token
     * @throws RuntimeException if the token cannot be acquired
     */
    private String getTokenFromIMDSEndpoint(String tokenAudience) {
        MSIToken token = cache.get(tokenAudience);
        if (token != null && !token.isExpired()) {
            return token.accessToken();
        }
        lock.lock();
        try {
            // Another thread may have refreshed the entry while this one waited for the lock.
            token = cache.get(tokenAudience);
            if (token != null && !token.isExpired()) {
                return token.accessToken();
            }
            try {
                token = retrieveTokenFromIDMSWithRetry(tokenAudience);
            } catch (IOException exception) {
                throw new RuntimeException(exception);
            }
            if (token == null) {
                throw new RuntimeException("MSI: IMDS returned an empty token response");
            }
            cache.put(tokenAudience, token);
            return token.accessToken();
        } finally {
            lock.unlock();
        }
    }

    /**
     * Requests a token from the IMDS endpoint, retrying with a randomised back-off when the
     * response status is 404, 410, 429 or 5xx.
     *
     * @param tokenAudience the resource to request a token for
     * @return the deserialized token
     * @throws IOException if the request URL is malformed or the response status cannot be read
     * @throws RuntimeException if the status is not retriable, the custom timeout was hit, or
     *     all retries were used up
     */
    private MSIToken retrieveTokenFromIDMSWithRetry(String tokenAudience) throws IOException {
        String query = buildRequestParameters(tokenAudience, true);
        URL url = new URL(String.format("http://169.254.169.254/metadata/identity/oauth2/token?%s", query));
        boolean hasTimedout = false;
        Exception lastFailure = null;
        int retry = 1;
        while (retry <= maxRetry) {
            HttpURLConnection connection = null;
            // NOTE(review): the start time is taken per attempt, so customTimeout bounds each
            // attempt plus its back-off rather than the whole acquisition - confirm this is intended.
            long startTime = System.currentTimeMillis();
            try {
                connection = (HttpURLConnection) url.openConnection();
                connection.setRequestMethod("GET");
                connection.setRequestProperty("Metadata", "true");
                connection.connect();
                return readToken(connection);
            } catch (Exception exception) {
                if (connection == null) {
                    // openConnection itself failed, so there is no response status to inspect.
                    throw new RuntimeException("Couldn't open a connection to the IMDS endpoint", exception);
                }
                lastFailure = exception;
                int responseCode = connection.getResponseCode();
                if (!isRetriableStatus(responseCode)) {
                    throw new RuntimeException("Couldn't acquire access token from IMDS, verify your objectId, clientId or msiResourceId", exception);
                }
                if (hasTimedout) {
                    throw new RuntimeException("Couldn't acquire access token from IMDS within the specified timeout : " + this.customTimeout + " milliseconds", exception);
                }
                // Clamp the bound so a maxRetry larger than the table cannot index past its end.
                int slot = RANDOM.nextInt(Math.min(retry, retrySlots.size()));
                int retryTimeoutInMs = retrySlots.get(slot) * 1000;
                if (responseCode == 410 && retryTimeoutInMs < IMDS_UPGRADE_TIME_IN_MS) {
                    retryTimeoutInMs = IMDS_UPGRADE_TIME_IN_MS;
                }
                retry++;
                if (retry > maxRetry) {
                    break;
                }
                hasTimedout = sleep(retryTimeoutInMs, startTime);
            } finally {
                if (connection != null) {
                    connection.disconnect();
                }
            }
        }
        // The loop only ends once every allowed attempt has failed.
        throw new RuntimeException(String.format("MSI: Failed to acquire tokens after retrying %s times", maxRetry), lastFailure);
    }

    /**
     * Builds the URL-encoded {@code name=value} parameter string shared by both token sources:
     * the resource plus at most one of {@code object_id}, {@code client_id} or {@code msi_res_id}.
     *
     * @param tokenAudience the resource to request a token for
     * @param includeApiVersion whether to prepend the {@code api-version} parameter
     * @return the encoded parameter string
     */
    private String buildRequestParameters(String tokenAudience, boolean includeApiVersion) {
        StringBuilder parameters = new StringBuilder();
        try {
            if (includeApiVersion) {
                parameters.append("api-version=").append(URLEncoder.encode(IMDS_API_VERSION, UTF_8)).append('&');
            }
            parameters.append("resource=").append(URLEncoder.encode(tokenAudience, UTF_8));
            if (this.objectId != null) {
                parameters.append("&object_id=").append(URLEncoder.encode(this.objectId, UTF_8));
            } else if (this.clientId != null) {
                parameters.append("&client_id=").append(URLEncoder.encode(this.clientId, UTF_8));
            } else if (this.identityId != null) {
                parameters.append("&msi_res_id=").append(URLEncoder.encode(this.identityId, UTF_8));
            }
        } catch (IOException exception) {
            throw new RuntimeException(exception);
        }
        return parameters.toString();
    }

    /**
     * Reads the whole response body of the connection as UTF-8 and deserializes it into a token.
     *
     * @param connection a connection whose request has been sent
     * @return the result of deserializing the body; an empty body is passed to the adapter as {@code null}
     * @throws IOException if reading or deserializing the response fails
     */
    private MSIToken readToken(HttpURLConnection connection) throws IOException {
        try (InputStream stream = connection.getInputStream();
             BufferedReader reader = new BufferedReader(new InputStreamReader(stream, UTF_8))) {
            StringBuilder body = new StringBuilder();
            String line;
            while ((line = reader.readLine()) != null) {
                body.append(line);
            }
            String result = body.length() == 0 ? null : body.toString();
            return adapter.deserialize(result, MSIToken.class);
        }
    }

    /**
     * Tells whether a failed IMDS request with the given HTTP status should be retried.
     *
     * @param responseCode the HTTP status code
     * @return {@code true} for 404, 410, 429 and 5xx
     */
    private static boolean isRetriableStatus(int responseCode) {
        return responseCode == 410
                || responseCode == 429
                || responseCode == 404
                || (responseCode >= 500 && responseCode <= 599);
    }

    /**
     * Sleeps for the requested back-off, shortened so it does not run past
     * {@code startTime + customTimeout} when a custom timeout is configured.
     *
     * @param timeToWaitinMs the requested back-off in milliseconds
     * @param startTime the time in epoch milliseconds the timeout is measured from
     * @return {@code true} if the actual sleep differed from the requested back-off
     */
    private boolean sleep(int timeToWaitinMs, long startTime) {
        long timeToSleep = timeToWaitinMs;
        if (this.customTimeout > -1) {
            long timeRemainingToTimeout = startTime + this.customTimeout - System.currentTimeMillis();
            timeToSleep = Math.max(0L, Math.min(timeToWaitinMs, timeRemainingToTimeout));
        }
        sleep(timeToSleep);
        return timeToSleep != timeToWaitinMs;
    }

    /**
     * Sleeps for the given number of milliseconds.
     *
     * @param millis the time to sleep
     * @throws RuntimeException if the thread is interrupted; the interrupt status is restored first
     */
    private static void sleep(long millis) {
        try {
            Thread.sleep(millis);
        } catch (InterruptedException ex) {
            // Restore the flag so callers further up the stack can still observe the interrupt.
            Thread.currentThread().interrupt();
            throw new RuntimeException(ex);
        }
    }

    /**
     * @return the maximum number of attempts made against the IMDS endpoint
     */
    public int maxRetry() {
        return maxRetry;
    }

    /**
     * Sets the maximum number of attempts made against the IMDS endpoint.
     *
     * @param maxRetry the maximum number of attempts
     */
    public void setMaxRetry(int maxRetry) {
        this.maxRetry = maxRetry;
    }

    /**
     * The source from which tokens are retrieved.
     */
    private enum MSITokenSource {
        /** The MSI VM extension listening on a localhost port. */
        MSI_EXTENSION,
        /** The instance metadata service endpoint. */
        IMDS_ENDPOINT
    }
}