package io.undertow.protocols.ssl;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import javax.net.ssl.SNIHostName;
import javax.net.ssl.SNIMatcher;
import javax.net.ssl.SNIServerName;
import javax.net.ssl.SSLContext;
import io.undertow.UndertowMessages;
/**
 * Selects an {@link SSLContext} based on the SNI server names requested by a client.
 *
 * <p>Lookup order is: every matcher registered with a name not containing {@code *}
 * (in registration order), then every matcher registered with a name containing {@code *}
 * (in registration order), and finally the default context. The matcher maps are copied on
 * construction and never modified afterwards.
 */
public class SNIContextMatcher {

    /** Returned when no registered matcher accepts any requested server name. */
    private final SSLContext defaultContext;
    /** Matchers whose registration name contained {@code *}, in insertion order. */
    private final Map<SNIMatcher, SSLContext> wildcards;
    /** Matchers whose registration name did not contain {@code *}, in insertion order. */
    private final Map<SNIMatcher, SSLContext> exacts;

    /**
     * Creates a matcher from the given default context and matcher maps.
     *
     * @param defaultContext context returned when nothing matches
     * @param wildcards      matchers consulted second; copied, so later changes to the
     *                       argument do not affect this instance
     * @param exacts         matchers consulted first; copied, so later changes to the
     *                       argument do not affect this instance
     */
    SNIContextMatcher(SSLContext defaultContext, Map<SNIMatcher, SSLContext> wildcards, Map<SNIMatcher, SSLContext> exacts) {
        this.defaultContext = defaultContext;
        // Defensive copies: Builder.build() hands over its own live maps. Without copying,
        // further addMatch() calls on a reused builder would mutate this matcher after the
        // fact (and could race with getContext() iterating the same map).
        this.wildcards = new LinkedHashMap<>(wildcards);
        this.exacts = new LinkedHashMap<>(exacts);
    }

    /**
     * Chooses the context for a client's requested server names.
     *
     * @param servers the SNI server names requested by the client; {@code null} or empty
     *                yields the default context
     * @return the context of the first exact matcher (in registration order) that matches any
     *         requested name, otherwise the first such wildcard matcher, otherwise the
     *         default context
     */
    public SSLContext getContext(List<SNIServerName> servers) {
        if (servers == null || servers.isEmpty()) {
            return defaultContext;
        }
        // Priority is decided by matcher registration order, not by the order of the
        // client's requested names.
        for (Map.Entry<SNIMatcher, SSLContext> entry : exacts.entrySet()) {
            for (SNIServerName server : servers) {
                if (entry.getKey().matches(server)) {
                    return entry.getValue();
                }
            }
        }
        for (Map.Entry<SNIMatcher, SSLContext> entry : wildcards.entrySet()) {
            for (SNIServerName server : servers) {
                if (entry.getKey().matches(server)) {
                    return entry.getValue();
                }
            }
        }
        return defaultContext;
    }

    /**
     * Returns the context used when no matcher applies.
     *
     * @return the default context (never {@code null} for instances produced by
     *         {@link Builder#build()})
     */
    public SSLContext getDefaultContext() {
        return defaultContext;
    }

    /**
     * Mutable builder for {@link SNIContextMatcher}. Not synchronized. Instances built from it
     * are unaffected by later changes to the builder.
     */
    public static class Builder {
        private SSLContext defaultContext;
        private final Map<SNIMatcher, SSLContext> wildcards = new LinkedHashMap<>();
        private final Map<SNIMatcher, SSLContext> exacts = new LinkedHashMap<>();

        /**
         * Builds a matcher from the current configuration.
         *
         * @return a new matcher holding a snapshot of the registered matches
         * @throws RuntimeException the unchecked exception produced by
         *         {@code UndertowMessages.MESSAGES.defaultContextCannotBeNull()} if no default
         *         context has been set
         */
        public SNIContextMatcher build() {
            if (defaultContext == null) {
                throw UndertowMessages.MESSAGES.defaultContextCannotBeNull();
            }
            return new SNIContextMatcher(defaultContext, wildcards, exacts);
        }

        /**
         * Returns the default context configured so far.
         *
         * @return the default context, or {@code null} if none has been set yet
         */
        public SSLContext getDefaultContext() {
            return defaultContext;
        }

        /**
         * Sets the context returned when no registered match applies.
         *
         * @param defaultContext the fallback context; must be non-null by the time
         *                       {@link #build()} is called
         * @return this builder
         */
        public Builder setDefaultContext(SSLContext defaultContext) {
            this.defaultContext = defaultContext;
            return this;
        }

        /**
         * Registers a context for server names matching {@code name}.
         *
         * <p>{@code name} is passed to {@link SNIHostName#createSNIMatcher(String)}, which treats
         * it as a regular expression. Names containing {@code *} are consulted after all other
         * names.
         *
         * <p>NOTE(review): because {@code name} is a regex, a shell-style wildcard such as
         * {@code "*.example.com"} is rejected with a {@code PatternSyntaxException}
         * (dangling {@code *}); use {@code ".*\\.example\\.com"} instead. Likewise, an
         * unescaped {@code .} in a "plain" host name matches any character.
         *
         * <p>NOTE(review): the JDK matcher presumably uses identity equality, so registering
         * the same name twice adds a second entry and the earlier one wins; confirm.
         *
         * @param name    regular expression for the server name(s) to match; must not be null
         * @param context context to select on a match; must not be null
         * @return this builder
         * @throws NullPointerException if {@code name} or {@code context} is null
         */
        public Builder addMatch(String name, SSLContext context) {
            Objects.requireNonNull(name, "name");
            // Reject null here: a null value would otherwise make getContext() return null
            // for a matching client instead of a usable context.
            Objects.requireNonNull(context, "context");
            SNIMatcher matcher = SNIHostName.createSNIMatcher(name);
            if (name.contains("*")) {
                wildcards.put(matcher, context);
            } else {
                exacts.put(matcher, context);
            }
            return this;
        }
    }
}